#!/usr/bin/env python3
"""
Main entry point of the spatio-temporal deep-learning forecasting project.

Specialised for spatio-temporal data laid out as
(batch_size, seq_len, num_nodes, features).
"""
|
|
|
|
import os
|
|
from utils.args_reader import config_loader
|
|
import utils.init as init
|
|
import torch
|
|
|
|
|
|
def main():
|
|
config = config_loader()
|
|
device = config['basic']['device'] = init.device(config['basic']['device'])
|
|
init.seed(config['basic']['seed'])
|
|
model = init.model(config)
|
|
train_loader, val_loader, test_loader, scaler = init.dataloader(config)
|
|
loss = init.loss(config, scaler)
|
|
optim, lr = init.optimizer(config, model)
|
|
logger = init.Logger(config)
|
|
trainer = init.trainer(config, model, loss, optim, train_loader, val_loader, test_loader, scaler, logger, lr)
|
|
match config['basic']['mode']:
|
|
case 'train':
|
|
trainer.train()
|
|
case 'test':
|
|
params_path = f"./pre-trained/{config['basic']['model']}/{config['basic']['dataset']}.pth"
|
|
params = torch.load(params_path, map_location=device, weights_only=True)
|
|
model.load_state_dict(params)
|
|
trainer.test(model.to(device), config, test_loader, scaler, trainer.logger)
|
|
|
|
|
|
if __name__ == "__main__":
    # Only run the pipeline when executed as a script, not when imported.
    main()
|
|
|