#!/usr/bin/env python3
"""
STEP model test script.

Builds the STEP model, data loaders, loss, optimizer, LR scheduler, and
trainer, and runs a single forward pass on one batch.
"""

import torch
import yaml
import os
import sys

# Add the project root directory to the path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from model.model_selector import model_selector
from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
from lib.loss_function import masked_mae_loss
from lib.normalization import normalize_dataset

def test_step_model():
    """Test the STEP model."""
    print("Starting STEP model test...")

    # Load the configuration
    config_path = "config/STEP/STEP_PEMS04.yaml"
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    print(f"Loaded config file: {config_path}")

    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

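    # Config keys this script reads (a sketch; the YAML may contain more):
    #   model                -> passed to model_selector
    #   train.lr_init        -> Adam learning rate
    #   train.weight_decay   -> Adam weight decay
    #   train.lr_decay_step  -> MultiStepLR milestones
    #   train.lr_decay_rate  -> MultiStepLR gamma
    #   (get_dataloader receives the whole config dict)
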
    try:
        # Build the model
        print("Building the STEP model...")
        model = model_selector(config['model'])
        model = model.to(device)
        print(f"Number of model parameters: {sum(p.numel() for p in model.parameters())}")

        # Build the data loaders
        print("Building the data loaders...")
        train_loader, val_loader, test_loader, scaler = get_dataloader(
            config, normalizer='std', single=True
        )
        print(f"Training batches: {len(train_loader)}")
        print(f"Validation batches: {len(val_loader)}")
        print(f"Test batches: {len(test_loader)}")

        # Test the model's forward pass
        print("Testing the forward pass...")
        model.eval()
        with torch.no_grad():
            for data, target in train_loader:
                data = data.to(device)
                target = target.to(device)

                print(f"Input shape: {data.shape}")
                print(f"Target shape: {target.shape}")

                # Forward pass
                output = model(data)
                print(f"Output shape: {output.shape}")

                # Test the loss computation
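                # (masked_mae_loss is called as a factory that returns a loss
                # callable; the two None arguments follow the original call,
                # presumably a scaler and a mask value in this project's API,
                # though that is an assumption.)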
                loss_fn = masked_mae_loss(None, None)
                loss = loss_fn(output, target)
                print(f"Loss: {loss.item():.4f}")

                # Only the first batch is needed for this smoke test
                break

        # Build the optimizer
        print("Building the optimizer...")
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=config['train']['lr_init'],
            weight_decay=config['train']['weight_decay']
        )

        # Build the learning-rate scheduler
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=config['train']['lr_decay_step'],
            gamma=config['train']['lr_decay_rate']
        )

        # Build the trainer
        print("Building the trainer...")
        trainer = select_trainer(
            model=model,
            loss=masked_mae_loss,
            optimizer=optimizer,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            scaler=scaler,
            args=config,
            lr_scheduler=lr_scheduler,
            kwargs=[]
        )
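
        # Note: this smoke test only constructs the trainer; no training
        # epochs are run here.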

        print("STEP model test finished!")
        print("Model construction, forward pass, and trainer setup all succeeded.")

        return True

    except Exception as e:
        print(f"STEP model test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    success = test_step_model()
    if success:
        print("\n✅ STEP model adaptation succeeded!")
    else:
        print("\n❌ STEP model adaptation failed!")