#!/usr/bin/env python3 """ 时空数据深度学习预测项目主程序 专门处理时空数据格式 (batch_size, seq_len, num_nodes, features) """ import os from utils.args_reader import config_loader import utils.init as init import torch def main(): config = config_loader() device = config['basic']['device'] = init.device(config['basic']['device']) init.seed(config['basic']['seed']) model = init.model(config) train_loader, val_loader, test_loader, scaler = init.dataloader(config) loss = init.loss(config, scaler) optim, lr = init.optimizer(config, model) logger = init.Logger(config) trainer = init.trainer(config, model, loss, optim, train_loader, val_loader, test_loader, scaler, logger, lr) match config['basic']['mode']: case 'train': trainer.train() case 'test': params_path = f"./pre-trained/{config['basic']['model']}/{config['basic']['dataset']}.pth" params = torch.load(params_path, map_location=device, weights_only=True) model.load_state_dict(params) trainer.test(model.to(device), config, test_loader, scaler, trainer.logger) if __name__ == "__main__": main()