新增消耗分析模式,只需在原有的mode中调整为benchmark即可

This commit is contained in:
czzhangheng 2025-03-26 12:38:57 +08:00
parent dcd4f99f9c
commit 5306d24408
1 changed files with 24 additions and 18 deletions

42
run.py
View File

@ -1,17 +1,14 @@
import os
import shutil
from torchview import draw_graph
# 检查数据集完整性
from lib.Download_data import check_and_download_data
data_complete = check_and_download_data()
assert data_complete is not None, "数据集下载失败,请重试!"
import torch
from datetime import datetime
# import time
from config.args_parser import parse_args
from lib.initializer import init_model, init_optimizer
from lib.loss_function import get_loss_function
@ -21,32 +18,41 @@ from trainer.trainer_selector import select_trainer
import yaml
def main():
args = parse_args()
# Set device
if torch.cuda.is_available():
if torch.cuda.is_available() and args['device'] != 'cpu':
torch.cuda.set_device(int(args['device'].split(':')[1]))
args['model']['device'] = args['device']
else:
args['device'] = 'cpu'
args['model']['device'] = args['device']
# Initialize model
model = init_model(args['model'], device=args['device'])
if args['mode'] == "draw":
dummy_input = torch.randn(64,12,307,3)
model_graph = draw_graph(model,
input_data = dummy_input,
device=args['device'],
show_shapes=True,
save_graph=True,
graph_name=f"{args['model']['type']}_graph",
directory="./",
format="png"
)
if args['mode'] == "benchmark":
# 支持计算消耗分析,设置 mode为 benchmark
import torch.profiler as profiler
dummy_input = torch.randn((64, 12, args['model']['num_nodes'], 3), device=args['device'])
min_val = dummy_input.min(dim=-1, keepdim=True)[0]
max_val = dummy_input.max(dim=-1, keepdim=True)[0]
dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6)
with profiler.profile(
activities=[
profiler.ProfilerActivity.CPU,
profiler.ProfilerActivity.CUDA
],
with_stack=True,
profile_memory=True,
record_shapes=True
) as prof:
out = model(dummy_input)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
return 0
# Load dataset
@ -85,7 +91,7 @@ def main():
# Start training or testing
trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
lr_scheduler, extra_data)
lr_scheduler, extra_data)
match args['mode']:
case 'train':