57 lines
1.6 KiB
Python
57 lines
1.6 KiB
Python
import torch
|
||
from model.model_selector import model_selector
|
||
import yaml
|
||
|
||
# 读取配置文件
|
||
with open('/user/czzhangheng/code/TrafficWheel/config/Informer/AirQuality.yaml', 'r') as f:
|
||
config = yaml.safe_load(f)
|
||
|
||
# 初始化模型
|
||
model = model_selector(config)
|
||
print('Informer模型初始化成功!')
|
||
print(f'模型参数数量: {sum(p.numel() for p in model.parameters())}')
|
||
|
||
# 创建测试数据
|
||
B, T, C = 2, 24, 6
|
||
x_enc = torch.randn(B, T, C)
|
||
|
||
# 测试1: 完整参数
|
||
print('\n测试1: 完整参数')
|
||
x_mark_enc = torch.randn(B, T, 4) # 假设时间特征为4维
|
||
x_dec = torch.randn(B, 12+24, C) # label_len + pred_len
|
||
x_mark_dec = torch.randn(B, 12+24, 4)
|
||
try:
|
||
output = model(x_enc, x_mark_enc, x_dec, x_mark_dec)
|
||
print(f'输出形状: {output.shape}')
|
||
print('测试1通过!')
|
||
except Exception as e:
|
||
print(f'测试1失败: {e}')
|
||
|
||
# 测试2: 省略x_mark_enc
|
||
print('\n测试2: 省略x_mark_enc')
|
||
try:
|
||
output = model(x_enc, x_dec=x_dec, x_mark_dec=x_mark_dec)
|
||
print(f'输出形状: {output.shape}')
|
||
print('测试2通过!')
|
||
except Exception as e:
|
||
print(f'测试2失败: {e}')
|
||
|
||
# 测试3: 省略x_dec和x_mark_dec
|
||
print('\n测试3: 省略x_dec和x_mark_dec')
|
||
try:
|
||
output = model(x_enc, x_mark_enc=x_mark_enc)
|
||
print(f'输出形状: {output.shape}')
|
||
print('测试3通过!')
|
||
except Exception as e:
|
||
print(f'测试3失败: {e}')
|
||
|
||
# 测试4: 仅传入x_enc
|
||
print('\n测试4: 仅传入x_enc')
|
||
try:
|
||
output = model(x_enc)
|
||
print(f'输出形状: {output.shape}')
|
||
print('测试4通过!')
|
||
except Exception as e:
|
||
print(f'测试4失败: {e}')
|
||
|
||
print('\n所有测试完成!') |