#!/usr/bin/env python3
"""Training script for the STEP model."""

import argparse
import os
import sys
import traceback

import torch
import yaml

# Add the directory containing this script (used as the project root) to the
# import path so the project packages imported below can be resolved.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from model.model_selector import model_selector
from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
from lib.loss_function import masked_mae_loss


def train_step_model(config_path, epochs=None):
    """Train the STEP model described by a YAML configuration file.

    Args:
        config_path: Path to the YAML configuration file. The code reads the
            keys ``model``, ``data.type`` and ``train.{epochs, lr_init,
            weight_decay, lr_decay_step, lr_decay_rate}`` from it.
        epochs: Optional number of epochs. When not ``None`` it overrides
            ``config['train']['epochs']``.

    Returns:
        bool: ``True`` when model construction and training finished without
        raising, ``False`` when any exception occurred in that phase (the
        traceback is printed).

    Raises:
        Errors raised while opening/parsing the configuration file or while
        creating the log directory are NOT caught here and propagate to the
        caller (e.g. ``OSError``, ``yaml.YAMLError``, ``KeyError``).
    """
    print(f"开始训练STEP模型,配置文件: {config_path}")

    # Load the configuration.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # A command-line epoch count takes precedence over the config file.
    if epochs is not None:
        config['train']['epochs'] = epochs

    # Prefer the GPU when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # Create the log directory (named after the dataset type).
    log_dir = f'./logs/STEP_{config["data"]["type"]}'
    os.makedirs(log_dir, exist_ok=True)

    try:
        # Build the model and move it to the selected device.
        print("创建STEP模型...")
        model = model_selector(config['model'])
        model = model.to(device)
        print(f"模型参数数量: {sum(p.numel() for p in model.parameters())}")

        # Build the data loaders; the call also yields the scaler that is
        # handed to the trainer below.
        print("创建数据加载器...")
        train_loader, val_loader, test_loader, scaler = get_dataloader(
            config,
            normalizer='std',
            single=True,
        )
        print(f"训练集批次数: {len(train_loader)}")
        print(f"验证集批次数: {len(val_loader)}")
        print(f"测试集批次数: {len(test_loader)}")

        # Build the optimizer.
        print("创建优化器...")
        train_cfg = config['train']
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=train_cfg['lr_init'],
            weight_decay=train_cfg['weight_decay'],
        )

        # Build the learning-rate scheduler: the rate is multiplied by
        # ``lr_decay_rate`` at each epoch listed in ``lr_decay_step``.
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=train_cfg['lr_decay_step'],
            gamma=train_cfg['lr_decay_rate'],
        )

        # Build the trainer.
        print("创建训练器...")
        trainer = select_trainer(
            model=model,
            loss=masked_mae_loss,
            optimizer=optimizer,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            scaler=scaler,
            args=config,
            lr_scheduler=lr_scheduler,
            kwargs=[],
        )

        # Run training; the trainer is expected to return the best
        # validation and test losses as a pair.
        print(f"开始训练,总epochs: {train_cfg['epochs']}")
        best_val_loss, best_test_loss = trainer.train()

        print("训练完成!")
        print(f"最佳验证损失: {best_val_loss:.4f}")
        print(f"最佳测试损失: {best_test_loss:.4f}")
        return True

    except Exception as e:
        # Top-level boundary of the training run: report the failure with a
        # full traceback and signal it through the return value.
        print(f"STEP模型训练失败: {e}")
        traceback.print_exc()
        return False


def main():
    """Parse command-line arguments, run training and report the outcome.

    Returns:
        int: Process exit status, ``0`` when training succeeded and ``1``
        when it failed, so that shells and schedulers can detect failures.
    """
    parser = argparse.ArgumentParser(description='训练STEP模型')
    parser.add_argument('--config', type=str,
                        default='config/STEP/STEP_PEMS04.yaml',
                        help='配置文件路径')
    parser.add_argument('--epochs', type=int, default=None,
                        help='训练轮数(覆盖配置文件中的设置)')
    args = parser.parse_args()

    success = train_step_model(args.config, args.epochs)

    if success:
        print("\n✅ STEP模型训练完成!")
        return 0

    print("\n❌ STEP模型训练失败!")
    return 1


if __name__ == "__main__":
    # Propagate the training outcome as the process exit status.
    sys.exit(main())