TrafficWheel/train_step.py

119 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Training script for the STEP model.
"""
import torch
import yaml
import os
import sys
import argparse
# Append the directory containing this script (the project root, per the
# original comment) to sys.path before the project-local imports below.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from model.model_selector import model_selector
from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
from lib.loss_function import masked_mae_loss
def train_step_model(config_path, epochs=None):
    """Train the STEP model described by a YAML configuration file.

    Args:
        config_path: Path to the YAML configuration file. This function reads
            ``config['model']``, ``config['data']['type']`` and, under
            ``config['train']``, the keys ``lr_init``, ``weight_decay``,
            ``lr_decay_step``, ``lr_decay_rate`` and ``epochs``.
        epochs: When not ``None``, written into ``config['train']['epochs']``
            before anything else uses the configuration.

    Returns:
        ``True`` when model construction, data loading and training all
        complete; ``False`` when any ``Exception`` is raised during those
        steps (the message and traceback are printed).

    Errors raised while reading the configuration, applying the epoch
    override or creating the log directory happen before the ``try`` block
    and therefore propagate to the caller.
    """
    print(f"开始训练STEP模型配置文件: {config_path}")

    # Load the configuration.
    with open(config_path, 'r', encoding='utf-8') as cfg_file:
        cfg = yaml.safe_load(cfg_file)

    # A command-line epoch count takes precedence over the file's value.
    if epochs is not None:
        cfg['train']['epochs'] = epochs

    # Pick the compute device.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(f"使用设备: {device}")

    # Make sure the per-dataset log directory exists.
    dataset_type = cfg["data"]["type"]
    os.makedirs(f'./logs/STEP_{dataset_type}', exist_ok=True)

    try:
        # Model.
        print("创建STEP模型...")
        model = model_selector(cfg['model']).to(device)
        param_count = sum(p.numel() for p in model.parameters())
        print(f"模型参数数量: {param_count}")

        # Data loaders.
        print("创建数据加载器...")
        loader_bundle = get_dataloader(cfg, normalizer='std', single=True)
        train_dl, val_dl, test_dl, scaler = loader_bundle
        print(f"训练集批次数: {len(train_dl)}")
        print(f"验证集批次数: {len(val_dl)}")
        print(f"测试集批次数: {len(test_dl)}")

        # Optimizer and learning-rate schedule.
        print("创建优化器...")
        train_cfg = cfg['train']
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=train_cfg['lr_init'],
            weight_decay=train_cfg['weight_decay'],
        )
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=train_cfg['lr_decay_step'],
            gamma=train_cfg['lr_decay_rate'],
        )

        # Trainer.
        print("创建训练器...")
        trainer_options = dict(
            model=model,
            loss=masked_mae_loss,
            optimizer=optimizer,
            train_loader=train_dl,
            val_loader=val_dl,
            test_loader=test_dl,
            scaler=scaler,
            args=cfg,
            lr_scheduler=scheduler,
            kwargs=[],
        )
        trainer = select_trainer(**trainer_options)

        # Run training; trainer.train() yields the two values reported below.
        print(f"开始训练总epochs: {cfg['train']['epochs']}")
        best_val_loss, best_test_loss = trainer.train()
        print("训练完成!")
        print(f"最佳验证损失: {best_val_loss:.4f}")
        print(f"最佳测试损失: {best_test_loss:.4f}")
    except Exception as exc:
        # Report the failure and signal it through the return value.
        print(f"STEP模型训练失败: {exc}")
        import traceback
        traceback.print_exc()
        return False

    return True
def main():
    """Parse command-line options, run STEP training and report the outcome.

    Command-line options:
        --config: Path to the YAML configuration file
            (default ``config/STEP/STEP_PEMS04.yaml``).
        --epochs: Optional integer that overrides the epoch count from the
            configuration file.

    Returns:
        int: ``0`` when ``train_step_model`` reports success, ``1`` otherwise.
        The value is used as the process exit status when the file is run
        as a script.
    """
    parser = argparse.ArgumentParser(description='训练STEP模型')
    parser.add_argument('--config', type=str, default='config/STEP/STEP_PEMS04.yaml',
                        help='配置文件路径')
    parser.add_argument('--epochs', type=int, default=None,
                        help='训练轮数(覆盖配置文件中的设置)')
    args = parser.parse_args()

    success = train_step_model(args.config, args.epochs)
    if success:
        print("\n✅ STEP模型训练完成")
    else:
        print("\n❌ STEP模型训练失败")

    # Previously a failed run still exited with status 0, so shells, CI jobs
    # and schedulers could not tell that training had failed.
    return 0 if success else 1


if __name__ == "__main__":
    # Propagate main()'s result as the process exit status.
    sys.exit(main())