TrafficWheel/test_step.py

123 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
STEP模型测试脚本
"""
import torch
import yaml
import os
import sys
# 添加项目根目录到路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from model.model_selector import model_selector
from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
from lib.loss_function import masked_mae_loss
from lib.normalization import normalize_dataset
def test_step_model():
"""测试STEP模型"""
print("开始测试STEP模型...")
# 加载配置
config_path = "config/STEP/STEP_PEMS04.yaml"
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
print(f"加载配置文件: {config_path}")
# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
try:
# 创建模型
print("创建STEP模型...")
model = model_selector(config['model'])
model = model.to(device)
print(f"模型参数数量: {sum(p.numel() for p in model.parameters())}")
# 创建数据加载器
print("创建数据加载器...")
train_loader, val_loader, test_loader, scaler = get_dataloader(
config, normalizer='std', single=True
)
print(f"训练集批次数: {len(train_loader)}")
print(f"验证集批次数: {len(val_loader)}")
print(f"测试集批次数: {len(test_loader)}")
# 测试模型前向传播
print("测试模型前向传播...")
model.eval()
with torch.no_grad():
for batch_idx, (data, target) in enumerate(train_loader):
if batch_idx >= 1: # 只测试第一个批次
break
data = data.to(device)
target = target.to(device)
print(f"输入数据形状: {data.shape}")
print(f"目标数据形状: {target.shape}")
# 前向传播
output = model(data)
print(f"输出数据形状: {output.shape}")
# 测试损失计算
loss_fn = masked_mae_loss(None, None)
loss = loss_fn(output, target)
print(f"损失值: {loss.item():.4f}")
break
# 创建优化器
print("创建优化器...")
optimizer = torch.optim.Adam(
model.parameters(),
lr=config['train']['lr_init'],
weight_decay=config['train']['weight_decay']
)
# 创建学习率调度器
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=config['train']['lr_decay_step'],
gamma=config['train']['lr_decay_rate']
)
# 创建训练器
print("创建训练器...")
trainer = select_trainer(
model=model,
loss=masked_mae_loss,
optimizer=optimizer,
train_loader=train_loader,
val_loader=val_loader,
test_loader=test_loader,
scaler=scaler,
args=config,
lr_scheduler=lr_scheduler,
kwargs=[]
)
print("STEP模型测试完成")
print("模型可以正常创建、前向传播和训练。")
return True
except Exception as e:
print(f"STEP模型测试失败: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
success = test_step_model()
if success:
print("\n✅ STEP模型适配成功")
else:
print("\n❌ STEP模型适配失败")