import os

# Verify dataset integrity and download it if missing
from lib.Download_data import check_and_download_data

data_complete = check_and_download_data()
assert data_complete is not None, "Dataset download failed, please retry!"

import torch
from datetime import datetime
# import time
from config.args_parser import parse_args
from lib.initializer import init_model, init_optimizer, init_seed
from lib.loss_function import get_loss_function
from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
import yaml


def main():
    args = parse_args()

    # Set device
    if torch.cuda.is_available() and args['device'] != 'cpu':
        torch.cuda.set_device(int(args['device'].split(':')[1]))
        args['model']['device'] = args['device']
    else:
        args['device'] = 'cpu'
        args['model']['device'] = args['device']

    init_seed(args['train']['seed'])

    # Initialize model
    model = init_model(args['model'], device=args['device'])

    # if args['mode'] == "benchmark":
    #     # For compute-cost profiling, set mode to "benchmark"
    #     import torch.profiler as profiler
    #     dummy_input = torch.randn((64, 12, args['model']['num_nodes'], 3), device=args['device'])
    #     min_val = dummy_input.min(dim=-1, keepdim=True)[0]
    #     max_val = dummy_input.max(dim=-1, keepdim=True)[0]
    #     # dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6)
    #     with profiler.profile(
    #             activities=[
    #                 profiler.ProfilerActivity.CPU,
    #                 profiler.ProfilerActivity.CUDA
    #             ],
    #             with_stack=True,
    #             profile_memory=True,
    #             record_shapes=True
    #     ) as prof:
    #         out = model(dummy_input)
    #     print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    #     return 0

    # Load dataset
    train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
        args,
        normalizer=args['data']['normalizer'],
        single=False
    )

    # Initialize loss function
    loss = get_loss_function(args['train'], scaler)

    # Initialize optimizer and learning rate scheduler
    optimizer, lr_scheduler = init_optimizer(model, args['train'])

    # Configure log path
    current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    current_dir = os.path.dirname(os.path.realpath(__file__))
    args['train']['log_dir'] = os.path.join(current_dir, 'experiments', args['data']['type'], current_time)

    # Config file name, e.g. "<dataset>.yaml"
    config_filename = f"{args['data']['type']}.yaml"
    # config_path = os.path.join(current_dir, 'config', args['model']['type'], config_filename)

    # Ensure the log directory exists
    os.makedirs(args['train']['log_dir'], exist_ok=True)

    # Serialize the resolved args to YAML and save a copy alongside the logs
    config_content = yaml.safe_dump(args, default_flow_style=False)
    destination_path = os.path.join(args['train']['log_dir'], config_filename)
    with open(destination_path, 'w') as f:
        f.write(config_content)

    # Start training or testing
    trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader,
                             scaler, args, lr_scheduler, extra_data)
    match args['mode']:
        case 'train':
            trainer.train()
        case 'test':
            model.load_state_dict(torch.load(
                f"./pre-trained/{args['model']['type']}/{args['data']['type']}.pth",
                map_location=args['device'],
                weights_only=True))
            # print(f"Loaded saved model on {args['device']}")
            trainer.test(model.to(args['device']), trainer.args, test_loader, scaler, trainer.logger)
        case _:
            raise ValueError(f"Unsupported mode: {args['mode']}")


if __name__ == '__main__':
    main()