fix STID multi output_dim
This commit is contained in:
parent
7e2fa532de
commit
c8416c5b01
|
|
@ -10,10 +10,10 @@ data:
|
|||
column_wise: false
|
||||
days_per_week: 7
|
||||
horizon: 24
|
||||
input_dim: 1
|
||||
input_dim: 6
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 12
|
||||
num_nodes: 35
|
||||
steps_per_day: 48
|
||||
test_ratio: 0.2
|
||||
val_ratio: 0.2
|
||||
|
|
@ -27,12 +27,12 @@ model:
|
|||
if_D_i_W: true
|
||||
if_T_i_D: true
|
||||
if_node: true
|
||||
input_dim: 3
|
||||
input_dim: 8
|
||||
input_len: 24
|
||||
node_dim: 32
|
||||
num_layer: 3
|
||||
num_nodes: 12
|
||||
output_dim: 1
|
||||
num_nodes: 35
|
||||
output_dim: 6
|
||||
output_len: 24
|
||||
temp_dim_diw: 32
|
||||
temp_dim_tid: 32
|
||||
|
|
@ -54,7 +54,7 @@ train:
|
|||
mae_thresh: 0.0
|
||||
mape_thresh: 0.001
|
||||
max_grad_norm: 5
|
||||
output_dim: 1
|
||||
output_dim: 6
|
||||
plot: false
|
||||
real_value: true
|
||||
weight_decay: 0.0001
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ class STID(nn.Module):
|
|||
self.input_dim = model_args["input_dim"]
|
||||
self.embed_dim = model_args["embed_dim"]
|
||||
self.output_len = model_args["output_len"]
|
||||
self.output_dim = model_args["output_dim"]
|
||||
self.num_layer = model_args["num_layer"]
|
||||
self.temp_dim_tid = model_args["temp_dim_tid"]
|
||||
self.temp_dim_diw = model_args["temp_dim_diw"]
|
||||
|
|
@ -57,7 +58,7 @@ class STID(nn.Module):
|
|||
|
||||
self.regression_layer = nn.Conv2d(
|
||||
in_channels=self.hidden_dim,
|
||||
out_channels=self.output_len,
|
||||
out_channels=self.output_len * self.output_dim,
|
||||
kernel_size=(1, 1),
|
||||
bias=True,
|
||||
)
|
||||
|
|
@ -104,6 +105,8 @@ class STID(nn.Module):
|
|||
|
||||
hidden = torch.cat([time_series_emb] + node_emb + tem_emb, dim=1)
|
||||
hidden = self.encoder(hidden)
|
||||
prediction = self.regression_layer(hidden)
|
||||
prediction = prediction.permute(0, 1, 3, 2) # [B, t, n, c]
|
||||
prediction = self.regression_layer(hidden) # [B, output_len * output_dim, 1, N]
|
||||
prediction = prediction.squeeze(2).permute(0, 2, 1) # [B, N, output_len * output_dim]
|
||||
prediction = prediction.view(prediction.shape[0], prediction.shape[1], self.output_len, self.output_dim)
|
||||
prediction = prediction.permute(0, 2, 1, 3) # [B, output_len, N, output_dim]
|
||||
return prediction # [B, t, n, c]
|
||||
|
|
|
|||
8
train.py
8
train.py
|
|
@ -92,7 +92,7 @@ def read_config(config_path):
|
|||
# 全局配置
|
||||
device = "cuda:0" # 指定设备为cuda:0
|
||||
seed = 2023 # 随机种子
|
||||
epochs = 50 # 训练轮数
|
||||
epochs = 10 # 训练轮数
|
||||
|
||||
# 拷贝项
|
||||
config["basic"]["seed"] = seed
|
||||
|
|
@ -110,7 +110,7 @@ def read_config(config_path):
|
|||
if __name__ == "__main__":
|
||||
# 调试用
|
||||
# model_list = ["iTransformer", "PatchTST", "HI"]
|
||||
model_list = ["Informer"]
|
||||
model_list = ["STID"]
|
||||
# model_list = ["PatchTST"]
|
||||
|
||||
air = ["AirQuality"]
|
||||
|
|
@ -121,5 +121,5 @@ if __name__ == "__main__":
|
|||
|
||||
all_dataset = big_dataset + mid_dataset + regular_dataset
|
||||
|
||||
dataset_list = test_dataset
|
||||
main(model_list, dataset_list, debug=False)
|
||||
dataset_list = regular_dataset
|
||||
main(model_list, dataset_list, debug=True)
|
||||
|
|
|
|||
Loading…
Reference in New Issue