fix STID multi output_dim
This commit is contained in:
parent
7e2fa532de
commit
c8416c5b01
|
|
@ -10,10 +10,10 @@ data:
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
input_dim: 1
|
input_dim: 6
|
||||||
lag: 24
|
lag: 24
|
||||||
normalizer: std
|
normalizer: std
|
||||||
num_nodes: 12
|
num_nodes: 35
|
||||||
steps_per_day: 48
|
steps_per_day: 48
|
||||||
test_ratio: 0.2
|
test_ratio: 0.2
|
||||||
val_ratio: 0.2
|
val_ratio: 0.2
|
||||||
|
|
@ -27,12 +27,12 @@ model:
|
||||||
if_D_i_W: true
|
if_D_i_W: true
|
||||||
if_T_i_D: true
|
if_T_i_D: true
|
||||||
if_node: true
|
if_node: true
|
||||||
input_dim: 3
|
input_dim: 8
|
||||||
input_len: 24
|
input_len: 24
|
||||||
node_dim: 32
|
node_dim: 32
|
||||||
num_layer: 3
|
num_layer: 3
|
||||||
num_nodes: 12
|
num_nodes: 35
|
||||||
output_dim: 1
|
output_dim: 6
|
||||||
output_len: 24
|
output_len: 24
|
||||||
temp_dim_diw: 32
|
temp_dim_diw: 32
|
||||||
temp_dim_tid: 32
|
temp_dim_tid: 32
|
||||||
|
|
@ -54,7 +54,7 @@ train:
|
||||||
mae_thresh: 0.0
|
mae_thresh: 0.0
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 1
|
output_dim: 6
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
weight_decay: 0.0001
|
weight_decay: 0.0001
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ class STID(nn.Module):
|
||||||
self.input_dim = model_args["input_dim"]
|
self.input_dim = model_args["input_dim"]
|
||||||
self.embed_dim = model_args["embed_dim"]
|
self.embed_dim = model_args["embed_dim"]
|
||||||
self.output_len = model_args["output_len"]
|
self.output_len = model_args["output_len"]
|
||||||
|
self.output_dim = model_args["output_dim"]
|
||||||
self.num_layer = model_args["num_layer"]
|
self.num_layer = model_args["num_layer"]
|
||||||
self.temp_dim_tid = model_args["temp_dim_tid"]
|
self.temp_dim_tid = model_args["temp_dim_tid"]
|
||||||
self.temp_dim_diw = model_args["temp_dim_diw"]
|
self.temp_dim_diw = model_args["temp_dim_diw"]
|
||||||
|
|
@ -57,7 +58,7 @@ class STID(nn.Module):
|
||||||
|
|
||||||
self.regression_layer = nn.Conv2d(
|
self.regression_layer = nn.Conv2d(
|
||||||
in_channels=self.hidden_dim,
|
in_channels=self.hidden_dim,
|
||||||
out_channels=self.output_len,
|
out_channels=self.output_len * self.output_dim,
|
||||||
kernel_size=(1, 1),
|
kernel_size=(1, 1),
|
||||||
bias=True,
|
bias=True,
|
||||||
)
|
)
|
||||||
|
|
@ -104,6 +105,8 @@ class STID(nn.Module):
|
||||||
|
|
||||||
hidden = torch.cat([time_series_emb] + node_emb + tem_emb, dim=1)
|
hidden = torch.cat([time_series_emb] + node_emb + tem_emb, dim=1)
|
||||||
hidden = self.encoder(hidden)
|
hidden = self.encoder(hidden)
|
||||||
prediction = self.regression_layer(hidden)
|
prediction = self.regression_layer(hidden) # [B, output_len * output_dim, 1, N]
|
||||||
prediction = prediction.permute(0, 1, 3, 2) # [B, t, n, c]
|
prediction = prediction.squeeze(2).permute(0, 2, 1) # [B, N, output_len * output_dim]
|
||||||
|
prediction = prediction.view(prediction.shape[0], prediction.shape[1], self.output_len, self.output_dim)
|
||||||
|
prediction = prediction.permute(0, 2, 1, 3) # [B, output_len, N, output_dim]
|
||||||
return prediction # [B, t, n, c]
|
return prediction # [B, t, n, c]
|
||||||
|
|
|
||||||
8
train.py
8
train.py
|
|
@ -92,7 +92,7 @@ def read_config(config_path):
|
||||||
# 全局配置
|
# 全局配置
|
||||||
device = "cuda:0" # 指定设备为cuda:0
|
device = "cuda:0" # 指定设备为cuda:0
|
||||||
seed = 2023 # 随机种子
|
seed = 2023 # 随机种子
|
||||||
epochs = 50 # 训练轮数
|
epochs = 10 # 训练轮数
|
||||||
|
|
||||||
# 拷贝项
|
# 拷贝项
|
||||||
config["basic"]["seed"] = seed
|
config["basic"]["seed"] = seed
|
||||||
|
|
@ -110,7 +110,7 @@ def read_config(config_path):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 调试用
|
# 调试用
|
||||||
# model_list = ["iTransformer", "PatchTST", "HI"]
|
# model_list = ["iTransformer", "PatchTST", "HI"]
|
||||||
model_list = ["Informer"]
|
model_list = ["STID"]
|
||||||
# model_list = ["PatchTST"]
|
# model_list = ["PatchTST"]
|
||||||
|
|
||||||
air = ["AirQuality"]
|
air = ["AirQuality"]
|
||||||
|
|
@ -121,5 +121,5 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
all_dataset = big_dataset + mid_dataset + regular_dataset
|
all_dataset = big_dataset + mid_dataset + regular_dataset
|
||||||
|
|
||||||
dataset_list = test_dataset
|
dataset_list = regular_dataset
|
||||||
main(model_list, dataset_list, debug=False)
|
main(model_list, dataset_list, debug=True)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue