#!/usr/bin/env python3 """ STEP模型测试脚本 """ import torch import yaml import os import sys # 添加项目根目录到路径 sys.path.append(os.path.dirname(os.path.abspath(__file__))) from model.model_selector import model_selector from dataloader.loader_selector import get_dataloader from trainer.trainer_selector import select_trainer from lib.loss_function import masked_mae_loss from lib.normalization import normalize_dataset def test_step_model(): """测试STEP模型""" print("开始测试STEP模型...") # 加载配置 config_path = "config/STEP/STEP_PEMS04.yaml" with open(config_path, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) print(f"加载配置文件: {config_path}") # 设置设备 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"使用设备: {device}") try: # 创建模型 print("创建STEP模型...") model = model_selector(config['model']) model = model.to(device) print(f"模型参数数量: {sum(p.numel() for p in model.parameters())}") # 创建数据加载器 print("创建数据加载器...") train_loader, val_loader, test_loader, scaler = get_dataloader( config, normalizer='std', single=True ) print(f"训练集批次数: {len(train_loader)}") print(f"验证集批次数: {len(val_loader)}") print(f"测试集批次数: {len(test_loader)}") # 测试模型前向传播 print("测试模型前向传播...") model.eval() with torch.no_grad(): for batch_idx, (data, target) in enumerate(train_loader): if batch_idx >= 1: # 只测试第一个批次 break data = data.to(device) target = target.to(device) print(f"输入数据形状: {data.shape}") print(f"目标数据形状: {target.shape}") # 前向传播 output = model(data) print(f"输出数据形状: {output.shape}") # 测试损失计算 loss_fn = masked_mae_loss(None, None) loss = loss_fn(output, target) print(f"损失值: {loss.item():.4f}") break # 创建优化器 print("创建优化器...") optimizer = torch.optim.Adam( model.parameters(), lr=config['train']['lr_init'], weight_decay=config['train']['weight_decay'] ) # 创建学习率调度器 lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer, milestones=config['train']['lr_decay_step'], gamma=config['train']['lr_decay_rate'] ) # 创建训练器 print("创建训练器...") trainer = select_trainer( model=model, loss=masked_mae_loss, optimizer=optimizer, train_loader=train_loader, val_loader=val_loader, test_loader=test_loader, scaler=scaler, args=config, lr_scheduler=lr_scheduler, kwargs=[] ) print("STEP模型测试完成!") print("模型可以正常创建、前向传播和训练。") return True except Exception as e: print(f"STEP模型测试失败: {e}") import traceback traceback.print_exc() return False if __name__ == "__main__": success = test_step_model() if success: print("\n✅ STEP模型适配成功!") else: print("\n❌ STEP模型适配失败!")