From 5306d244081716258786d2b1ba50ea70fd9cdee3 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 26 Mar 2025 12:38:57 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=B6=88=E8=80=97=E5=88=86?= =?UTF-8?q?=E6=9E=90=E6=A8=A1=E5=BC=8F=EF=BC=8C=E5=8F=AA=E9=9C=80=E5=9C=A8?= =?UTF-8?q?=E5=8E=9F=E6=9C=89=E7=9A=84mode=E4=B8=AD=E8=B0=83=E6=95=B4?= =?UTF-8?q?=E4=B8=BAbenchmark=E5=8D=B3=E5=8F=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- run.py | 42 ++++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/run.py b/run.py index ffb66f2..16aaf9c 100644 --- a/run.py +++ b/run.py @@ -1,17 +1,14 @@ import os -import shutil -from torchview import draw_graph - # 检查数据集完整性 from lib.Download_data import check_and_download_data + data_complete = check_and_download_data() assert data_complete is not None, "数据集下载失败,请重试!" import torch from datetime import datetime # import time - from config.args_parser import parse_args from lib.initializer import init_model, init_optimizer from lib.loss_function import get_loss_function @@ -21,32 +18,41 @@ from trainer.trainer_selector import select_trainer import yaml + + def main(): args = parse_args() # Set device - if torch.cuda.is_available(): + if torch.cuda.is_available() and args['device'] != 'cpu': torch.cuda.set_device(int(args['device'].split(':')[1])) args['model']['device'] = args['device'] else: args['device'] = 'cpu' args['model']['device'] = args['device'] - # Initialize model model = init_model(args['model'], device=args['device']) - if args['mode'] == "draw": - dummy_input = torch.randn(64,12,307,3) - model_graph = draw_graph(model, - input_data = dummy_input, - device=args['device'], - show_shapes=True, - save_graph=True, - graph_name=f"{args['model']['type']}_graph", - directory="./", - format="png" - ) + if args['mode'] == "benchmark": + # 支持计算消耗分析,设置 mode为 benchmark + import torch.profiler as profiler + dummy_input = torch.randn((64, 12, args['model']['num_nodes'], 3), device=args['device']) + min_val = dummy_input.min(dim=-1, keepdim=True)[0] + max_val = dummy_input.max(dim=-1, keepdim=True)[0] + + dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6) + with profiler.profile( + activities=[ + profiler.ProfilerActivity.CPU, + profiler.ProfilerActivity.CUDA + ], + with_stack=True, + profile_memory=True, + record_shapes=True + ) as prof: + out = model(dummy_input) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) return 0 # Load dataset @@ -85,7 +91,7 @@ def main(): # Start training or testing trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, - lr_scheduler, extra_data) + lr_scheduler, extra_data) match args['mode']: case 'train':