Added a compute-cost (profiling) analysis mode; simply set the existing mode option to benchmark.

czzhangheng 2025-03-26 12:38:57 +08:00
parent dcd4f99f9c
commit 5306d24408
1 changed file with 24 additions and 18 deletions

run.py

@@ -1,17 +1,14 @@
 import os
-import shutil
-from torchview import draw_graph
 # Check dataset integrity
 from lib.Download_data import check_and_download_data
 data_complete = check_and_download_data()
 assert data_complete is not None, "Dataset download failed, please retry!"
 import torch
 from datetime import datetime
 # import time
 from config.args_parser import parse_args
 from lib.initializer import init_model, init_optimizer
 from lib.loss_function import get_loss_function
@@ -21,32 +18,41 @@ from trainer.trainer_selector import select_trainer
 import yaml
 def main():
     args = parse_args()
     # Set device
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() and args['device'] != 'cpu':
         torch.cuda.set_device(int(args['device'].split(':')[1]))
         args['model']['device'] = args['device']
     else:
         args['device'] = 'cpu'
         args['model']['device'] = args['device']
     # Initialize model
     model = init_model(args['model'], device=args['device'])
-    if args['mode'] == "draw":
-        dummy_input = torch.randn(64, 12, 307, 3)
-        model_graph = draw_graph(model,
-                                 input_data=dummy_input,
-                                 device=args['device'],
-                                 show_shapes=True,
-                                 save_graph=True,
-                                 graph_name=f"{args['model']['type']}_graph",
-                                 directory="./",
-                                 format="png"
-                                 )
+    if args['mode'] == "benchmark":
+        # Compute-cost profiling is enabled by setting mode to "benchmark"
+        import torch.profiler as profiler
+        dummy_input = torch.randn((64, 12, args['model']['num_nodes'], 3), device=args['device'])
+        min_val = dummy_input.min(dim=-1, keepdim=True)[0]
+        max_val = dummy_input.max(dim=-1, keepdim=True)[0]
+        dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6)
+        with profiler.profile(
+            activities=[
+                profiler.ProfilerActivity.CPU,
+                profiler.ProfilerActivity.CUDA
+            ],
+            with_stack=True,
+            profile_memory=True,
+            record_shapes=True
+        ) as prof:
+            out = model(dummy_input)
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
         return 0
     # Load dataset
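
For quick reference, the torch.profiler pattern introduced above can be exercised outside run.py. The sketch below is a minimal, hedged example under assumptions, not the project's code: ToyModel is a hypothetical stand-in for whatever init_model() would build, and the (64, 12, 307, 3) input shape is borrowed from the old draw branch rather than read from a config.

    # Minimal sketch of the profiling pattern from the benchmark branch above.
    # ToyModel and the input shape are stand-ins (assumptions), not the real model/config.
    import torch
    import torch.nn as nn
    import torch.profiler as profiler

    class ToyModel(nn.Module):
        """Hypothetical stand-in acting on the last (feature) dimension."""
        def __init__(self, in_features=3, out_features=1):
            super().__init__()
            self.linear = nn.Linear(in_features, out_features)

        def forward(self, x):
            return self.linear(x)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ToyModel().to(device)

    # Dummy batch (batch, steps, nodes, features), min-max normalized along the
    # feature dimension, mirroring the benchmark branch.
    x = torch.randn((64, 12, 307, 3), device=device)
    min_val = x.min(dim=-1, keepdim=True)[0]
    max_val = x.max(dim=-1, keepdim=True)[0]
    x = (x - min_val) / (max_val - min_val + 1e-6)

    # Record CUDA activity only when a GPU is actually available.
    activities = [profiler.ProfilerActivity.CPU]
    if device == "cuda":
        activities.append(profiler.ProfilerActivity.CUDA)

    with profiler.profile(
        activities=activities,
        with_stack=True,
        profile_memory=True,
        record_shapes=True,
    ) as prof:
        out = model(x)

    # cuda_time_total is only populated when CUDA activity was recorded,
    # so fall back to cpu_time_total on CPU-only machines.
    sort_key = "cuda_time_total" if device == "cuda" else "cpu_time_total"
    print(prof.key_averages().table(sort_by=sort_key, row_limit=10))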