Project-I/main.py

36 lines
1.2 KiB
Python

#!/usr/bin/env python3
"""
时空数据深度学习预测项目主程序
专门处理时空数据格式 (batch_size, seq_len, num_nodes, features)
"""
import os
from utils.args_reader import config_loader
import utils.init as init
import torch
def main():
config = config_loader()
device = config['basic']['device'] = init.device(config['basic']['device'])
init.seed(config['basic']['seed'])
model = init.model(config)
train_loader, val_loader, test_loader, scaler = init.dataloader(config)
loss = init.loss(config, scaler)
optim, lr = init.optimizer(config, model)
logger = init.Logger(config)
trainer = init.trainer(config, model, loss, optim, train_loader, val_loader, test_loader, scaler, logger, lr)
match config['basic']['mode']:
case 'train':
trainer.train()
case 'test':
params_path = f"./pre-trained/{config['basic']['model']}/{config['basic']['dataset']}.pth"
params = torch.load(params_path, map_location=device, weights_only=True)
model.load_state_dict(params)
trainer.test(model.to(device), config, test_loader, scaler, trainer.logger)
if __name__ == "__main__":
    # Start the pipeline only when this file is executed as a script;
    # importing the module (e.g. from tests) must not trigger a run.
    main()