import os

# Check dataset integrity and download any missing files before the heavier imports
from lib.Download_data import check_and_download_data

data_complete = check_and_download_data()
assert data_complete is not None, "Dataset download failed, please try again!"

import torch
from datetime import datetime
# import time
from config.args_parser import parse_args
from lib.initializer import init_model, init_optimizer
from lib.loss_function import get_loss_function
from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
import yaml


def main():
    args = parse_args()

    # Set device: use the requested CUDA device if available, otherwise fall back to CPU
    if torch.cuda.is_available() and args['device'] != 'cpu':
        torch.cuda.set_device(int(args['device'].split(':')[1]))
    else:
        args['device'] = 'cpu'
    args['model']['device'] = args['device']

    # Initialize model
    model = init_model(args['model'], device=args['device'])

    if args['mode'] == "benchmark":
        # Computational-cost analysis: set mode to "benchmark" to profile one forward pass and exit
        import torch.profiler as profiler

        # Random batch of shape (batch, horizon, num_nodes, features),
        # min-max normalized over the feature dimension
        dummy_input = torch.randn((64, 12, args['model']['num_nodes'], 3), device=args['device'])
        min_val = dummy_input.min(dim=-1, keepdim=True)[0]
        max_val = dummy_input.max(dim=-1, keepdim=True)[0]
        dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6)

        with profiler.profile(
            activities=[
                profiler.ProfilerActivity.CPU,
                profiler.ProfilerActivity.CUDA
            ],
            with_stack=True,
            profile_memory=True,
            record_shapes=True
        ) as prof:
            model(dummy_input)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        return 0

    # Load dataset
    train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
        args,
        normalizer=args['data']['normalizer'],
        single=False
    )

    # Initialize loss function
    loss = get_loss_function(args['train'], scaler)

    # Initialize optimizer and learning rate scheduler
    optimizer, lr_scheduler = init_optimizer(model, args['train'])

    # Configure log path
    current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    current_dir = os.path.dirname(os.path.realpath(__file__))
    args['train']['log_dir'] = os.path.join(current_dir, 'experiments', args['data']['type'], current_time)

    # Config file name for this dataset
    config_filename = f"{args['data']['type']}.yaml"
    # config_path = os.path.join(current_dir, 'config', args['model']['type'], config_filename)

    # Make sure the log directory exists
    os.makedirs(args['train']['log_dir'], exist_ok=True)

    # Serialize args to YAML so the exact run configuration is archived with the logs
    config_content = yaml.safe_dump(args, default_flow_style=False)

    # Destination path of the YAML snapshot inside the log directory
    destination_path = os.path.join(args['train']['log_dir'], config_filename)

    # Save args as a YAML file
    with open(destination_path, 'w') as f:
        f.write(config_content)

    # Start training or testing
    trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader,
                             scaler, args, lr_scheduler, extra_data)

    match args['mode']:
        case 'train':
            trainer.train()
        case 'test':
            model.load_state_dict(torch.load(
                f"./pre-trained/{args['model']['type']}/{args['data']['type']}.pth",
                map_location=args['device'],
                weights_only=True))
            # print(f"Loaded saved model on {args['device']}")
            trainer.test(model.to(args['device']), trainer.args, test_loader, scaler, trainer.logger)
        case _:
            raise ValueError(f"Unsupported mode: {args['mode']}")


if __name__ == '__main__':
    main()
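
# ---------------------------------------------------------------------------
# Hedged sketch (illustration only, never called by the pipeline above): the
# YAML snapshot that main() writes into args['train']['log_dir'] can be read
# back to inspect or reproduce a run. The "<dataset>.yaml" filename mirrors
# config_filename in main(); the helper name and signature are assumptions.
def load_run_config(log_dir: str, dataset_type: str) -> dict:
    """Reload the args dict that main() dumped with yaml.safe_dump."""
    with open(os.path.join(log_dir, f"{dataset_type}.yaml"), 'r') as f:
        return yaml.safe_load(f)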