新增计算消耗分析(benchmark)模式:只需将原有的 mode 参数设置为 benchmark 即可启用
This commit is contained in:
parent
dcd4f99f9c
commit
5306d24408
42
run.py
42
run.py
|
|
@ -1,17 +1,14 @@
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
from torchview import draw_graph
|
|
||||||
|
|
||||||
|
|
||||||
# 检查数据集完整性
|
# 检查数据集完整性
|
||||||
from lib.Download_data import check_and_download_data
|
from lib.Download_data import check_and_download_data
|
||||||
|
|
||||||
data_complete = check_and_download_data()
|
data_complete = check_and_download_data()
|
||||||
assert data_complete is not None, "数据集下载失败,请重试!"
|
assert data_complete is not None, "数据集下载失败,请重试!"
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
# import time
|
# import time
|
||||||
|
|
||||||
from config.args_parser import parse_args
|
from config.args_parser import parse_args
|
||||||
from lib.initializer import init_model, init_optimizer
|
from lib.initializer import init_model, init_optimizer
|
||||||
from lib.loss_function import get_loss_function
|
from lib.loss_function import get_loss_function
|
||||||
|
|
@ -21,32 +18,41 @@ from trainer.trainer_selector import select_trainer
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# Set device
|
# Set device
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available() and args['device'] != 'cpu':
|
||||||
torch.cuda.set_device(int(args['device'].split(':')[1]))
|
torch.cuda.set_device(int(args['device'].split(':')[1]))
|
||||||
args['model']['device'] = args['device']
|
args['model']['device'] = args['device']
|
||||||
else:
|
else:
|
||||||
args['device'] = 'cpu'
|
args['device'] = 'cpu'
|
||||||
args['model']['device'] = args['device']
|
args['model']['device'] = args['device']
|
||||||
|
|
||||||
|
|
||||||
# Initialize model
|
# Initialize model
|
||||||
model = init_model(args['model'], device=args['device'])
|
model = init_model(args['model'], device=args['device'])
|
||||||
|
|
||||||
if args['mode'] == "draw":
|
if args['mode'] == "benchmark":
|
||||||
dummy_input = torch.randn(64,12,307,3)
|
# 支持计算消耗分析,设置 mode为 benchmark
|
||||||
model_graph = draw_graph(model,
|
import torch.profiler as profiler
|
||||||
input_data = dummy_input,
|
dummy_input = torch.randn((64, 12, args['model']['num_nodes'], 3), device=args['device'])
|
||||||
device=args['device'],
|
min_val = dummy_input.min(dim=-1, keepdim=True)[0]
|
||||||
show_shapes=True,
|
max_val = dummy_input.max(dim=-1, keepdim=True)[0]
|
||||||
save_graph=True,
|
|
||||||
graph_name=f"{args['model']['type']}_graph",
|
dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6)
|
||||||
directory="./",
|
with profiler.profile(
|
||||||
format="png"
|
activities=[
|
||||||
)
|
profiler.ProfilerActivity.CPU,
|
||||||
|
profiler.ProfilerActivity.CUDA
|
||||||
|
],
|
||||||
|
with_stack=True,
|
||||||
|
profile_memory=True,
|
||||||
|
record_shapes=True
|
||||||
|
) as prof:
|
||||||
|
out = model(dummy_input)
|
||||||
|
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# Load dataset
|
# Load dataset
|
||||||
|
|
@ -85,7 +91,7 @@ def main():
|
||||||
|
|
||||||
# Start training or testing
|
# Start training or testing
|
||||||
trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
|
trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
|
||||||
lr_scheduler, extra_data)
|
lr_scheduler, extra_data)
|
||||||
|
|
||||||
match args['mode']:
|
match args['mode']:
|
||||||
case 'train':
|
case 'train':
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue