TrafficWheel/run.py

113 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
# Verify the dataset is present before anything else runs; the helper
# presumably downloads missing files (per its name) — returns None on failure.
from lib.Download_data import check_and_download_data
data_complete = check_and_download_data()
if data_complete is None:
    # Raise explicitly rather than `assert`: assertions are stripped under
    # `python -O`, which would let a run continue without its dataset.
    # Message: "Dataset download failed, please retry!"
    raise RuntimeError("数据集下载失败,请重试!")
import torch
from datetime import datetime
# import time
from config.args_parser import parse_args
from lib.initializer import init_model, init_optimizer, init_seed
from lib.loss_function import get_loss_function
from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
import yaml
def main():
args = parse_args()
# Set device
if torch.cuda.is_available() and args['device'] != 'cpu':
torch.cuda.set_device(int(args['device'].split(':')[1]))
args['model']['device'] = args['device']
else:
args['device'] = 'cpu'
args['model']['device'] = args['device']
init_seed(args['train']['seed'])
# Initialize model
model = init_model(args['model'], device=args['device'])
if args['mode'] == "benchmark":
# 支持计算消耗分析,设置 mode为 benchmark
import torch.profiler as profiler
dummy_input = torch.randn((64, 12, args['model']['num_nodes'], 3), device=args['device'])
min_val = dummy_input.min(dim=-1, keepdim=True)[0]
max_val = dummy_input.max(dim=-1, keepdim=True)[0]
dummy_input = (dummy_input - min_val) / (max_val - min_val + 1e-6)
with profiler.profile(
activities=[
profiler.ProfilerActivity.CPU,
profiler.ProfilerActivity.CUDA
],
with_stack=True,
profile_memory=True,
record_shapes=True
) as prof:
out = model(dummy_input)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
return 0
# Load dataset
train_loader, val_loader, test_loader, scaler, *extra_data = get_dataloader(
args,
normalizer=args['data']['normalizer'],
single=False
)
# Initialize loss function
loss = get_loss_function(args['train'], scaler)
# Initialize optimizer and learning rate scheduler
optimizer, lr_scheduler = init_optimizer(model, args['train'])
# Configure log path
current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
current_dir = os.path.dirname(os.path.realpath(__file__))
args['train']['log_dir'] = os.path.join(current_dir, 'experiments', args['data']['type'], current_time)
# 配置文件路径
config_filename = f"{args['data']['type']}.yaml"
# config_path = os.path.join(current_dir, 'config', args['model']['type'], config_filename)
# 确保日志目录存在
os.makedirs(args['train']['log_dir'], exist_ok=True)
# 生成配置文件内容(将 args 转换为 YAML 格式)
config_content = yaml.safe_dump(args, default_flow_style=False)
# 生成新的 YAML 文件名例如config.auto.yaml 或其他名称)
destination_path = os.path.join(args['train']['log_dir'], config_filename)
# 将 args 保存为 YAML 文件
with open(destination_path, 'w') as f:
f.write(config_content)
# Start training or testing
trainer = select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
lr_scheduler, extra_data)
match args['mode']:
case 'train':
trainer.train()
case 'test':
model.load_state_dict(torch.load(
f"./pre-trained/{args['model']['type']}/{args['data']['type']}.pth",
map_location=args['device'], weights_only=True))
# print(f"Loaded saved model on {args['device']}")
trainer.test(model.to(args['device']), trainer.args, test_loader, scaler, trainer.logger)
case _:
raise ValueError(f"Unsupported mode: {args['mode']}")
if __name__ == '__main__':
    # Run only when executed as a script; main() returns 0 or None, both of
    # which map to a zero exit status.
    raise SystemExit(main())