diff --git a/run.py b/run.py
index ffb66f2..16aaf9c 100644
--- a/run.py
+++ b/run.py
@@ -1,17 +1,14 @@
 import os
-import shutil
-from torchview import draw_graph
-
 # Check dataset integrity
 from lib.Download_data import check_and_download_data
+
 data_complete = check_and_download_data()
 assert data_complete is not None, "Dataset download failed, please retry!"
 import torch
 from datetime import datetime
 # import time
-
 from config.args_parser import parse_args
 from lib.initializer import init_model, init_optimizer
 from lib.loss_function import get_loss_function
@@ -21,32 +18,41 @@ from trainer.trainer_selector import select_trainer
 import yaml
+
+
 def main():
     args = parse_args()
 
     # Set device
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() and args['device'] != 'cpu':
         torch.cuda.set_device(int(args['device'].split(':')[1]))
         args['model']['device'] = args['device']
     else:
         args['device'] = 'cpu'
         args['model']['device'] = args['device']
-
     # Initialize model
     model = init_model(args['model'], device=args['device'])
 
-    if args['mode'] == "draw":
-        dummy_input = torch.randn(64,12,307,3)
-        model_graph = draw_graph(model,
-                                 input_data = dummy_input,
-                                 device=args['device'],
-                                 show_shapes=True,
-                                 save_graph=True,
-                                 graph_name=f"{args['model']['type']}_graph",
-                                 directory="./",
-                                 format="png"
-                                 )
+    if args['mode'] == "benchmark":
+        # Compute-cost profiling: enabled by setting mode to "benchmark"
+        import torch.profiler as profiler
+        dummy_input = torch.randn((64, 12, args['model']['num_nodes'], 3), device=args['device'])
+        min_val = dummy_input.min(dim=-1, keepdim=True)[0]
+        max_val = dummy_input.max(dim=-1, keepdim=True)[0]
+
+        dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6)
+        with profiler.profile(
+            activities=[
+                profiler.ProfilerActivity.CPU,
+                profiler.ProfilerActivity.CUDA
+            ],
+            with_stack=True,
+            profile_memory=True,
+            record_shapes=True
+        ) as prof:
+            out = model(dummy_input)
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
         return 0
 
     # Load dataset
@@ -85,7 +91,7 @@ def main():
     # Start training or testing
     trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
-                             lr_scheduler, extra_data)
+                             lr_scheduler, extra_data)
 
     match args['mode']:
         case 'train':
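
For reference, a minimal standalone sketch of the torch.profiler flow this patch adds under the new "benchmark" mode. The model and input shape below are placeholders (not taken from this repo, which uses args['model']['num_nodes']), and only CPU activity is profiled so the snippet also runs without a GPU:

    import torch
    import torch.profiler as profiler

    # Placeholder model and (batch, steps, nodes, channels) input, standing in
    # for the project's model and its configured node count.
    model = torch.nn.Linear(3, 1)
    dummy_input = torch.randn(64, 12, 307, 3)

    # Same min-max scaling over the last (channel) dimension as in the patch.
    min_val = dummy_input.min(dim=-1, keepdim=True)[0]
    max_val = dummy_input.max(dim=-1, keepdim=True)[0]
    dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6)

    with profiler.profile(
        activities=[profiler.ProfilerActivity.CPU],  # add ProfilerActivity.CUDA on GPU
        profile_memory=True,
        record_shapes=True,
    ) as prof:
        model(dummy_input)

    # Summarize operator-level cost, mirroring the table printed in run.py.
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))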