新增计算消耗分析(benchmark)模式:只需将原有配置中的 mode 设置为 benchmark 即可启用
This commit is contained in:
parent
dcd4f99f9c
commit
5306d24408
40
run.py
40
run.py
|
|
@ -1,17 +1,14 @@
|
|||
import os
|
||||
import shutil
|
||||
from torchview import draw_graph
|
||||
|
||||
|
||||
# Check dataset integrity
|
||||
from lib.Download_data import check_and_download_data
|
||||
|
||||
data_complete = check_and_download_data()
|
||||
assert data_complete is not None, "数据集下载失败,请重试!"
|
||||
|
||||
import torch
|
||||
from datetime import datetime
|
||||
# import time
|
||||
|
||||
from config.args_parser import parse_args
|
||||
from lib.initializer import init_model, init_optimizer
|
||||
from lib.loss_function import get_loss_function
|
||||
|
|
@ -21,32 +18,41 @@ from trainer.trainer_selector import select_trainer
|
|||
import yaml
|
||||
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Set device
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.is_available() and args['device'] != 'cpu':
|
||||
torch.cuda.set_device(int(args['device'].split(':')[1]))
|
||||
args['model']['device'] = args['device']
|
||||
else:
|
||||
args['device'] = 'cpu'
|
||||
args['model']['device'] = args['device']
|
||||
|
||||
|
||||
# Initialize model
|
||||
model = init_model(args['model'], device=args['device'])
|
||||
|
||||
if args['mode'] == "draw":
|
||||
dummy_input = torch.randn(64,12,307,3)
|
||||
model_graph = draw_graph(model,
|
||||
input_data = dummy_input,
|
||||
device=args['device'],
|
||||
show_shapes=True,
|
||||
save_graph=True,
|
||||
graph_name=f"{args['model']['type']}_graph",
|
||||
directory="./",
|
||||
format="png"
|
||||
)
|
||||
if args['mode'] == "benchmark":
|
||||
# Compute-cost profiling: enabled by setting mode to "benchmark"
|
||||
import torch.profiler as profiler
|
||||
dummy_input = torch.randn((64, 12, args['model']['num_nodes'], 3), device=args['device'])
|
||||
min_val = dummy_input.min(dim=-1, keepdim=True)[0]
|
||||
max_val = dummy_input.max(dim=-1, keepdim=True)[0]
|
||||
|
||||
dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6)
|
||||
with profiler.profile(
|
||||
activities=[
|
||||
profiler.ProfilerActivity.CPU,
|
||||
profiler.ProfilerActivity.CUDA
|
||||
],
|
||||
with_stack=True,
|
||||
profile_memory=True,
|
||||
record_shapes=True
|
||||
) as prof:
|
||||
out = model(dummy_input)
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
return 0
|
||||
|
||||
# Load dataset
|
||||
|
|
|
|||
Loading…
Reference in New Issue