TrafficWheel/test_informer.py

57 lines
1.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from model.model_selector import model_selector
import yaml
# 读取配置文件
with open('/user/czzhangheng/code/TrafficWheel/config/Informer/AirQuality.yaml', 'r') as f:
config = yaml.safe_load(f)
# 初始化模型
model = model_selector(config)
print('Informer模型初始化成功')
print(f'模型参数数量: {sum(p.numel() for p in model.parameters())}')
# 创建测试数据
B, T, C = 2, 24, 6
x_enc = torch.randn(B, T, C)
# 测试1: 完整参数
print('\n测试1: 完整参数')
x_mark_enc = torch.randn(B, T, 4) # 假设时间特征为4维
x_dec = torch.randn(B, 12+24, C) # label_len + pred_len
x_mark_dec = torch.randn(B, 12+24, 4)
try:
output = model(x_enc, x_mark_enc, x_dec, x_mark_dec)
print(f'输出形状: {output.shape}')
print('测试1通过')
except Exception as e:
print(f'测试1失败: {e}')
# 测试2: 省略x_mark_enc
print('\n测试2: 省略x_mark_enc')
try:
output = model(x_enc, x_dec=x_dec, x_mark_dec=x_mark_dec)
print(f'输出形状: {output.shape}')
print('测试2通过')
except Exception as e:
print(f'测试2失败: {e}')
# 测试3: 省略x_dec和x_mark_dec
print('\n测试3: 省略x_dec和x_mark_dec')
try:
output = model(x_enc, x_mark_enc=x_mark_enc)
print(f'输出形状: {output.shape}')
print('测试3通过')
except Exception as e:
print(f'测试3失败: {e}')
# 测试4: 仅传入x_enc
print('\n测试4: 仅传入x_enc')
try:
output = model(x_enc)
print(f'输出形状: {output.shape}')
print('测试4通过')
except Exception as e:
print(f'测试4失败: {e}')
print('\n所有测试完成!')