import torch from model.model_selector import model_selector import yaml # 读取配置文件 with open('/user/czzhangheng/code/TrafficWheel/config/Informer/AirQuality.yaml', 'r') as f: config = yaml.safe_load(f) # 初始化模型 model = model_selector(config) print('Informer模型初始化成功!') print(f'模型参数数量: {sum(p.numel() for p in model.parameters())}') # 创建测试数据 B, T, C = 2, 24, 6 x_enc = torch.randn(B, T, C) # 测试1: 完整参数 print('\n测试1: 完整参数') x_mark_enc = torch.randn(B, T, 4) # 假设时间特征为4维 x_dec = torch.randn(B, 12+24, C) # label_len + pred_len x_mark_dec = torch.randn(B, 12+24, 4) try: output = model(x_enc, x_mark_enc, x_dec, x_mark_dec) print(f'输出形状: {output.shape}') print('测试1通过!') except Exception as e: print(f'测试1失败: {e}') # 测试2: 省略x_mark_enc print('\n测试2: 省略x_mark_enc') try: output = model(x_enc, x_dec=x_dec, x_mark_dec=x_mark_dec) print(f'输出形状: {output.shape}') print('测试2通过!') except Exception as e: print(f'测试2失败: {e}') # 测试3: 省略x_dec和x_mark_dec print('\n测试3: 省略x_dec和x_mark_dec') try: output = model(x_enc, x_mark_enc=x_mark_enc) print(f'输出形状: {output.shape}') print('测试3通过!') except Exception as e: print(f'测试3失败: {e}') # 测试4: 仅传入x_enc print('\n测试4: 仅传入x_enc') try: output = model(x_enc) print(f'输出形状: {output.shape}') print('测试4通过!') except Exception as e: print(f'测试4失败: {e}') print('\n所有测试完成!')