diff --git a/config/STID/AirQuality.yaml b/config/STID/AirQuality.yaml index f499161..b480b4f 100755 --- a/config/STID/AirQuality.yaml +++ b/config/STID/AirQuality.yaml @@ -10,10 +10,10 @@ data: column_wise: false days_per_week: 7 horizon: 24 - input_dim: 1 + input_dim: 6 lag: 24 normalizer: std - num_nodes: 12 + num_nodes: 35 steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 @@ -27,12 +27,12 @@ model: if_D_i_W: true if_T_i_D: true if_node: true - input_dim: 3 + input_dim: 8 input_len: 24 node_dim: 32 num_layer: 3 - num_nodes: 12 - output_dim: 1 + num_nodes: 35 + output_dim: 6 output_len: 24 temp_dim_diw: 32 temp_dim_tid: 32 @@ -54,7 +54,7 @@ train: mae_thresh: 0.0 mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 1 + output_dim: 6 plot: false real_value: true weight_decay: 0.0001 diff --git a/model/STID/STID.py b/model/STID/STID.py index 713bc02..6d1ff0c 100755 --- a/model/STID/STID.py +++ b/model/STID/STID.py @@ -12,6 +12,7 @@ class STID(nn.Module): self.input_dim = model_args["input_dim"] self.embed_dim = model_args["embed_dim"] self.output_len = model_args["output_len"] + self.output_dim = model_args["output_dim"] self.num_layer = model_args["num_layer"] self.temp_dim_tid = model_args["temp_dim_tid"] self.temp_dim_diw = model_args["temp_dim_diw"] @@ -57,7 +58,7 @@ class STID(nn.Module): self.regression_layer = nn.Conv2d( in_channels=self.hidden_dim, - out_channels=self.output_len, + out_channels=self.output_len * self.output_dim, kernel_size=(1, 1), bias=True, ) @@ -104,6 +105,8 @@ class STID(nn.Module): hidden = torch.cat([time_series_emb] + node_emb + tem_emb, dim=1) hidden = self.encoder(hidden) - prediction = self.regression_layer(hidden) - prediction = prediction.permute(0, 1, 3, 2) # [B, t, n, c] + prediction = self.regression_layer(hidden) # [B, output_len * output_dim, 1, N] + prediction = prediction.squeeze(2).permute(0, 2, 1) # [B, N, output_len * output_dim] + prediction = prediction.view(prediction.shape[0], prediction.shape[1], self.output_len, self.output_dim) + prediction = prediction.permute(0, 2, 1, 3) # [B, output_len, N, output_dim] return prediction # [B, t, n, c] diff --git a/train.py b/train.py index 2294713..ecf1f01 100644 --- a/train.py +++ b/train.py @@ -92,7 +92,7 @@ def read_config(config_path): # 全局配置 device = "cuda:0" # 指定设备为cuda:0 seed = 2023 # 随机种子 - epochs = 50 # 训练轮数 + epochs = 10 # 训练轮数 # 拷贝项 config["basic"]["seed"] = seed @@ -110,7 +110,7 @@ def read_config(config_path): if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] - model_list = ["Informer"] + model_list = ["STID"] # model_list = ["PatchTST"] air = ["AirQuality"] @@ -121,5 +121,5 @@ if __name__ == "__main__": all_dataset = big_dataset + mid_dataset + regular_dataset - dataset_list = test_dataset - main(model_list, dataset_list, debug=False) + dataset_list = regular_dataset + main(model_list, dataset_list, debug=True)