fix STID multi output_dim

This commit is contained in:
czzhangheng 2026-01-04 10:01:01 +08:00
parent 7e2fa532de
commit c8416c5b01
3 changed files with 16 additions and 13 deletions

View File

@ -10,10 +10,10 @@ data:
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
input_dim: 1 input_dim: 6
lag: 24 lag: 24
normalizer: std normalizer: std
num_nodes: 12 num_nodes: 35
steps_per_day: 48 steps_per_day: 48
test_ratio: 0.2 test_ratio: 0.2
val_ratio: 0.2 val_ratio: 0.2
@ -27,12 +27,12 @@ model:
if_D_i_W: true if_D_i_W: true
if_T_i_D: true if_T_i_D: true
if_node: true if_node: true
input_dim: 3 input_dim: 8
input_len: 24 input_len: 24
node_dim: 32 node_dim: 32
num_layer: 3 num_layer: 3
num_nodes: 12 num_nodes: 35
output_dim: 1 output_dim: 6
output_len: 24 output_len: 24
temp_dim_diw: 32 temp_dim_diw: 32
temp_dim_tid: 32 temp_dim_tid: 32
@ -54,7 +54,7 @@ train:
mae_thresh: 0.0 mae_thresh: 0.0
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 1 output_dim: 6
plot: false plot: false
real_value: true real_value: true
weight_decay: 0.0001 weight_decay: 0.0001

View File

@ -12,6 +12,7 @@ class STID(nn.Module):
self.input_dim = model_args["input_dim"] self.input_dim = model_args["input_dim"]
self.embed_dim = model_args["embed_dim"] self.embed_dim = model_args["embed_dim"]
self.output_len = model_args["output_len"] self.output_len = model_args["output_len"]
self.output_dim = model_args["output_dim"]
self.num_layer = model_args["num_layer"] self.num_layer = model_args["num_layer"]
self.temp_dim_tid = model_args["temp_dim_tid"] self.temp_dim_tid = model_args["temp_dim_tid"]
self.temp_dim_diw = model_args["temp_dim_diw"] self.temp_dim_diw = model_args["temp_dim_diw"]
@ -57,7 +58,7 @@ class STID(nn.Module):
self.regression_layer = nn.Conv2d( self.regression_layer = nn.Conv2d(
in_channels=self.hidden_dim, in_channels=self.hidden_dim,
out_channels=self.output_len, out_channels=self.output_len * self.output_dim,
kernel_size=(1, 1), kernel_size=(1, 1),
bias=True, bias=True,
) )
@ -104,6 +105,8 @@ class STID(nn.Module):
hidden = torch.cat([time_series_emb] + node_emb + tem_emb, dim=1) hidden = torch.cat([time_series_emb] + node_emb + tem_emb, dim=1)
hidden = self.encoder(hidden) hidden = self.encoder(hidden)
prediction = self.regression_layer(hidden) prediction = self.regression_layer(hidden) # [B, output_len * output_dim, 1, N]
prediction = prediction.permute(0, 1, 3, 2) # [B, t, n, c] prediction = prediction.squeeze(2).permute(0, 2, 1) # [B, N, output_len * output_dim]
prediction = prediction.view(prediction.shape[0], prediction.shape[1], self.output_len, self.output_dim)
prediction = prediction.permute(0, 2, 1, 3) # [B, output_len, N, output_dim]
return prediction # [B, t, n, c] return prediction # [B, t, n, c]

View File

@ -92,7 +92,7 @@ def read_config(config_path):
# 全局配置 # 全局配置
device = "cuda:0" # 指定设备为cuda:0 device = "cuda:0" # 指定设备为cuda:0
seed = 2023 # 随机种子 seed = 2023 # 随机种子
epochs = 50 # 训练轮数 epochs = 10 # 训练轮数
# 拷贝项 # 拷贝项
config["basic"]["seed"] = seed config["basic"]["seed"] = seed
@ -110,7 +110,7 @@ def read_config(config_path):
if __name__ == "__main__": if __name__ == "__main__":
# 调试用 # 调试用
# model_list = ["iTransformer", "PatchTST", "HI"] # model_list = ["iTransformer", "PatchTST", "HI"]
model_list = ["Informer"] model_list = ["STID"]
# model_list = ["PatchTST"] # model_list = ["PatchTST"]
air = ["AirQuality"] air = ["AirQuality"]
@ -121,5 +121,5 @@ if __name__ == "__main__":
all_dataset = big_dataset + mid_dataset + regular_dataset all_dataset = big_dataset + mid_dataset + regular_dataset
dataset_list = test_dataset dataset_list = regular_dataset
main(model_list, dataset_list, debug=False) main(model_list, dataset_list, debug=True)