impl D2STGNN
parent 9d3293cef7
commit b46c16815e
config/D2STGNN/AirQuality.yaml
@@ -0,0 +1,60 @@
basic:
  dataset: AirQuality
  device: cuda:0
  mode: train
  model: D2STGNN
  seed: 2023

data:
  batch_size: 64
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 6
  lag: 24
  normalizer: std
  num_nodes: 35
  steps_per_day: 24
  test_ratio: 0.2
  val_ratio: 0.2

model:
  num_nodes: 35
  num_layers: 4
  num_hidden: 32
  forecast_dim: 256
  output_hidden: 512
  output_dim: 6
  seq_len: 24
  horizon: 24
  input_dim: 6
  num_timesteps_in_day: 24
  time_emb_dim: 10
  node_hidden: 10
  dy_graph: True
  sta_graph: False
  gap: 3
  k_s: 2
  k_t: 3
  dropout: 0.1

train:
  batch_size: 64
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  log_step: 1000
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: None
  mape_thresh: 0.001
  max_grad_norm: 5
  output_dim: 6
  plot: false
  real_value: true
  weight_decay: 0
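A minimal sketch of how one of these configs is consumed (mirroring read_config in train.py further below, which builds the path as ./config/{model}/{dataset}.yaml); the printed values are just the fields defined above:

import yaml

with open("./config/D2STGNN/AirQuality.yaml", "r") as file:
    config = yaml.safe_load(file)
# the basic / data / model / train sections map straight onto the keys above
print(config["basic"]["model"], config["data"]["num_nodes"], config["train"]["lr_init"])
# -> D2STGNN 35 0.003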
config/D2STGNN/BJTaxi-InFlow.yaml
@@ -0,0 +1,60 @@
basic:
  dataset: BJTaxi-InFlow
  device: cuda:0
  mode: train
  model: D2STGNN
  seed: 2023

data:
  batch_size: 16
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 1
  lag: 24
  normalizer: std
  num_nodes: 1024
  steps_per_day: 48
  test_ratio: 0.2
  val_ratio: 0.2

model:
  num_nodes: 1024
  num_layers: 4
  num_hidden: 32
  forecast_dim: 256
  output_hidden: 512
  output_dim: 1
  seq_len: 24
  horizon: 24
  input_dim: 1
  num_timesteps_in_day: 48
  time_emb_dim: 10
  node_hidden: 10
  dy_graph: True
  sta_graph: False
  gap: 3
  k_s: 2
  k_t: 3
  dropout: 0.1

train:
  batch_size: 16
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  log_step: 1000
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: None
  mape_thresh: 0.001
  max_grad_norm: 5
  output_dim: 1
  plot: false
  real_value: true
  weight_decay: 0
config/D2STGNN/BJTaxi-OutFlow.yaml
@@ -0,0 +1,60 @@
basic:
  dataset: BJTaxi-OutFlow
  device: cuda:0
  mode: train
  model: D2STGNN
  seed: 2023

data:
  batch_size: 16
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 1
  lag: 24
  normalizer: std
  num_nodes: 1024
  steps_per_day: 48
  test_ratio: 0.2
  val_ratio: 0.2

model:
  num_nodes: 1024
  num_layers: 4
  num_hidden: 32
  forecast_dim: 256
  output_hidden: 512
  output_dim: 1
  seq_len: 24
  horizon: 24
  input_dim: 1
  num_timesteps_in_day: 48
  time_emb_dim: 10
  node_hidden: 10
  dy_graph: True
  sta_graph: False
  gap: 3
  k_s: 2
  k_t: 3
  dropout: 0.1

train:
  batch_size: 16
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  log_step: 1000
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: None
  mape_thresh: 0.001
  max_grad_norm: 5
  output_dim: 1
  plot: false
  real_value: true
  weight_decay: 0
config/D2STGNN/METR-LA.yaml
@@ -0,0 +1,60 @@
basic:
  dataset: METR-LA
  device: cuda:0
  mode: train
  model: D2STGNN
  seed: 2023

data:
  batch_size: 16
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 1
  lag: 24
  normalizer: std
  num_nodes: 207
  steps_per_day: 288
  test_ratio: 0.2
  val_ratio: 0.2

model:
  num_nodes: 207
  num_layers: 4
  num_hidden: 32
  forecast_dim: 256
  output_hidden: 512
  output_dim: 1
  seq_len: 24
  horizon: 24
  input_dim: 1
  num_timesteps_in_day: 288
  time_emb_dim: 10
  node_hidden: 10
  dy_graph: True
  sta_graph: False
  gap: 3
  k_s: 2
  k_t: 3
  dropout: 0.1

train:
  batch_size: 16
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  log_step: 1000
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: None
  mape_thresh: 0.001
  max_grad_norm: 5
  output_dim: 1
  plot: false
  real_value: true
  weight_decay: 0
config/D2STGNN/NYCBike-InFlow.yaml
@@ -0,0 +1,60 @@
basic:
  dataset: NYCBike-InFlow
  device: cuda:0
  mode: train
  model: D2STGNN
  seed: 2023

data:
  batch_size: 64
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 1
  lag: 24
  normalizer: std
  num_nodes: 128
  steps_per_day: 48
  test_ratio: 0.2
  val_ratio: 0.2

model:
  num_nodes: 128
  num_layers: 4
  num_hidden: 32
  forecast_dim: 256
  output_hidden: 512
  output_dim: 1
  seq_len: 24
  horizon: 24
  input_dim: 1
  num_timesteps_in_day: 48
  time_emb_dim: 10
  node_hidden: 10
  dy_graph: True
  sta_graph: False
  gap: 3
  k_s: 2
  k_t: 3
  dropout: 0.1

train:
  batch_size: 64
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  log_step: 1000
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: None
  mape_thresh: 0.001
  max_grad_norm: 5
  output_dim: 1
  plot: false
  real_value: true
  weight_decay: 0
config/D2STGNN/NYCBike-OutFlow.yaml
@@ -0,0 +1,60 @@
basic:
  dataset: NYCBike-OutFlow
  device: cuda:0
  mode: train
  model: D2STGNN
  seed: 2023

data:
  batch_size: 16
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 1
  lag: 24
  normalizer: std
  num_nodes: 128
  steps_per_day: 48
  test_ratio: 0.2
  val_ratio: 0.2

model:
  num_nodes: 128
  num_layers: 4
  num_hidden: 32
  forecast_dim: 256
  output_hidden: 512
  output_dim: 1
  seq_len: 24
  horizon: 24
  input_dim: 1
  num_timesteps_in_day: 48
  time_emb_dim: 10
  node_hidden: 10
  dy_graph: True
  sta_graph: False
  gap: 3
  k_s: 2
  k_t: 3
  dropout: 0.1

train:
  batch_size: 16
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  log_step: 1000
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: None
  mape_thresh: 0.001
  max_grad_norm: 5
  output_dim: 1
  plot: false
  real_value: true
  weight_decay: 0
config/D2STGNN/PEMS-BAY.yaml
@@ -0,0 +1,60 @@
basic:
  dataset: PEMS-BAY
  device: cuda:0
  mode: train
  model: D2STGNN
  seed: 2023

data:
  batch_size: 16
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 1
  lag: 24
  normalizer: std
  num_nodes: 325
  steps_per_day: 288
  test_ratio: 0.2
  val_ratio: 0.2

model:
  num_nodes: 325
  num_layers: 4
  num_hidden: 32
  forecast_dim: 256
  output_hidden: 512
  output_dim: 1
  seq_len: 24
  horizon: 24
  input_dim: 1
  num_timesteps_in_day: 288
  time_emb_dim: 10
  node_hidden: 10
  dy_graph: True
  sta_graph: False
  gap: 3
  k_s: 2
  k_t: 3
  dropout: 0.1

train:
  batch_size: 16
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  log_step: 1000
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: None
  mape_thresh: 0.001
  max_grad_norm: 5
  output_dim: 1
  plot: false
  real_value: true
  weight_decay: 0
config/D2STGNN/SolarEnergy.yaml
@@ -0,0 +1,60 @@
basic:
  dataset: SolarEnergy
  device: cuda:0
  mode: train
  model: D2STGNN
  seed: 2023

data:
  batch_size: 64
  column_wise: false
  days_per_week: 7
  horizon: 24
  input_dim: 1
  lag: 24
  normalizer: std
  num_nodes: 137
  steps_per_day: 24
  test_ratio: 0.2
  val_ratio: 0.2

model:
  num_nodes: 137
  num_layers: 4
  num_hidden: 32
  forecast_dim: 256
  output_hidden: 512
  output_dim: 1
  seq_len: 24
  horizon: 24
  input_dim: 1
  num_timesteps_in_day: 24
  time_emb_dim: 10
  node_hidden: 10
  dy_graph: True
  sta_graph: False
  gap: 3
  k_s: 2
  k_t: 3
  dropout: 0.1

train:
  batch_size: 64
  debug: false
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  log_step: 1000
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: 5,20,40,70
  lr_init: 0.003
  mae_thresh: None
  mape_thresh: 0.001
  max_grad_norm: 5
  output_dim: 1
  plot: false
  real_value: true
  weight_decay: 0
model/D2STGNN/D2STGNN.py
@@ -0,0 +1,88 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.D2STGNN.diffusion_block.dif_block import DifBlock
from model.D2STGNN.inherent_block.inh_block import InhBlock
from model.D2STGNN.dynamic_graph_conv.dy_graph_conv import DynamicGraphConstructor
from model.D2STGNN.decouple.estimation_gate import EstimationGate


class DecoupleLayer(nn.Module):
    def __init__(self, hidden_dim, fk_dim, args):
        super().__init__()
        self.est_gate = EstimationGate(node_emb_dim=args['node_hidden'], time_emb_dim=args['time_emb_dim'], hidden_dim=64)
        # only the required positional arguments are passed explicitly; dy_graph is forwarded through **args
        self.dif_layer = DifBlock(hidden_dim, forecast_hidden_dim=fk_dim, **args)
        self.inh_layer = InhBlock(hidden_dim, forecast_hidden_dim=fk_dim, **args)

    def forward(self, x, dyn_graph, sta_graph=None, node_u=None, node_d=None, t_in_day=None, t_in_week=None):
        gated_x = self.est_gate(node_u, node_d, t_in_day, t_in_week, x)
        dif_back, dif_hidden = self.dif_layer(x, gated_x, dyn_graph, sta_graph)
        inh_back, inh_hidden = self.inh_layer(dif_back)
        return inh_back, dif_hidden, inh_hidden


class D2STGNN(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args  # keep args for use in forward
        self.num_nodes = args['num_nodes']
        self.num_layers = args['num_layers']
        self.hidden_dim = args['num_hidden']
        self.forecast_dim = args['forecast_dim']
        self.output_hidden = args['output_hidden']
        self.output_dim = args['output_dim']
        self.in_feat = args['input_dim']

        self.embedding = nn.Linear(self.in_feat, self.hidden_dim)
        self.T_i_D_emb = nn.Parameter(torch.empty(args.get('num_timesteps_in_day', 288), args['time_emb_dim']))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args['time_emb_dim']))
        self.node_u = nn.Parameter(torch.empty(self.num_nodes, args['node_hidden']))
        self.node_d = nn.Parameter(torch.empty(self.num_nodes, args['node_hidden']))

        self.layers = nn.ModuleList([DecoupleLayer(self.hidden_dim, self.forecast_dim, args) for _ in range(self.num_layers)])
        if args.get('dy_graph', False):
            self.dynamic_graph_constructor = DynamicGraphConstructor(**args)

        self.out_fc1 = nn.Linear(self.forecast_dim, self.output_hidden)
        self.out_fc2 = nn.Linear(self.output_hidden, args['gap'] * args['output_dim'])
        self._reset_parameters()

    def _reset_parameters(self):
        for p in [self.node_u, self.node_d, self.T_i_D_emb, self.D_i_W_emb]:
            nn.init.xavier_uniform_(p)

    def _prepare_inputs(self, x):
        node_u, node_d = self.node_u, self.node_d
        t_in_day = self.T_i_D_emb[(x[:, :, :, -2] * self.T_i_D_emb.size(0)).long()]
        t_in_week = self.D_i_W_emb[x[:, :, :, -1].long()]
        return x[:, :, :, :-2], node_u, node_d, t_in_day, t_in_week

    def _graph_constructor(self, node_u, node_d, x, t_in_day, t_in_week):
        # build the dynamic graph only; the static graph has been removed
        dyn_graph = self.dynamic_graph_constructor(node_u=node_u, node_d=node_d, history_data=x, time_in_day_feat=t_in_day, day_in_week_feat=t_in_week) if hasattr(self, 'dynamic_graph_constructor') else []
        return [], dyn_graph

    def forward(self, x):
        x, node_u, node_d, t_in_day, t_in_week = self._prepare_inputs(x)
        sta_graph, dyn_graph = self._graph_constructor(node_u, node_d, x, t_in_day, t_in_week)
        x = self.embedding(x)

        dif_hidden_list, inh_hidden_list = [], []
        backcast = x
        for layer in self.layers:
            backcast, dif_hidden, inh_hidden = layer(backcast, dyn_graph, sta_graph, node_u, node_d, t_in_day, t_in_week)
            dif_hidden_list.append(dif_hidden)
            inh_hidden_list.append(inh_hidden)

        forecast_hidden = sum(dif_hidden_list) + sum(inh_hidden_list)
        # reshape the output so that it matches the labels
        forecast = self.out_fc1(F.relu(forecast_hidden))
        forecast = F.relu(forecast)
        forecast = self.out_fc2(forecast)
        # make sure the feature dimension is correct
        if forecast.size(-1) != self.args['output_dim']:
            forecast = forecast[..., :self.args['output_dim']]
        # make sure the number of output time steps is correct
        if forecast.size(1) != self.args['horizon']:
            # if there are too few time steps, repeat along the time axis and truncate
            forecast = forecast.repeat(1, self.args['horizon'] // forecast.size(1) + 1, 1, 1)[:, :self.args['horizon'], :, :]
        return forecast
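A smoke-test sketch for the model above (not part of the commit, and not verified end to end): _prepare_inputs expects the last two input channels to hold the normalized time-of-day value and the day-of-week index, so a random input can be built as follows. The args dict simply copies the model section of config/D2STGNN/AirQuality.yaml.

import torch
from model.D2STGNN.D2STGNN import D2STGNN

args = {
    'num_nodes': 35, 'num_layers': 4, 'num_hidden': 32, 'forecast_dim': 256,
    'output_hidden': 512, 'output_dim': 6, 'seq_len': 24, 'horizon': 24,
    'input_dim': 6, 'num_timesteps_in_day': 24, 'time_emb_dim': 10,
    'node_hidden': 10, 'dy_graph': True, 'sta_graph': False,
    'gap': 3, 'k_s': 2, 'k_t': 3, 'dropout': 0.1,
}
model = D2STGNN(args)

B, L, N, D = 4, args['seq_len'], args['num_nodes'], args['input_dim']
x = torch.randn(B, L, N, D + 2)
x[..., -2] = torch.rand(B, L, N)                      # time-of-day as a fraction in [0, 1)
x[..., -1] = torch.randint(0, 7, (B, L, N)).float()   # day-of-week index 0..6
y = model(x)
print(y.shape)  # expected [B, horizon, num_nodes, output_dim] after the final reshaping above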
model/D2STGNN/decouple/estimation_gate.py
@@ -0,0 +1,24 @@
import torch
import torch.nn as nn


class EstimationGate(nn.Module):
    """The estimation gate module."""

    def __init__(self, node_emb_dim, time_emb_dim, hidden_dim):
        super().__init__()
        self.fully_connected_layer_1 = nn.Linear(2 * node_emb_dim + time_emb_dim * 2, hidden_dim)
        self.activation = nn.ReLU()
        self.fully_connected_layer_2 = nn.Linear(hidden_dim, 1)

    def forward(self, node_embedding_u, node_embedding_d, time_in_day_feat, day_in_week_feat, history_data):
        """Generate a gate value in (0, 1) from the node and time-step embeddings to roughly estimate the proportion of the two hidden time series."""

        batch_size, seq_length, _, _ = time_in_day_feat.shape
        estimation_gate_feat = torch.cat([time_in_day_feat, day_in_week_feat, node_embedding_u.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_length, -1, -1), node_embedding_d.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_length, -1, -1)], dim=-1)
        hidden = self.fully_connected_layer_1(estimation_gate_feat)
        hidden = self.activation(hidden)
        # gate value in (0, 1), aligned to the length of history_data
        estimation_gate = torch.sigmoid(self.fully_connected_layer_2(hidden))[:, -history_data.shape[1]:, :, :]
        history_data = history_data * estimation_gate
        return history_data
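A standalone shape check for the gate (illustrative sizes taken from the AirQuality config): the gate maps the concatenated node and time embeddings to a scalar in (0, 1) and rescales the hidden history signal, so the output keeps the input's shape.

import torch
from model.D2STGNN.decouple.estimation_gate import EstimationGate

gate = EstimationGate(node_emb_dim=10, time_emb_dim=10, hidden_dim=64)
B, L, N, D = 4, 24, 35, 32
out = gate(
    torch.randn(N, 10), torch.randn(N, 10),               # node embeddings E_u, E_d
    torch.randn(B, L, N, 10), torch.randn(B, L, N, 10),   # time-of-day / day-of-week features
    torch.randn(B, L, N, D),                               # hidden history signal
)
print(out.shape)  # torch.Size([4, 24, 35, 32])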
model/D2STGNN/decouple/residual_decomp.py
@@ -0,0 +1,15 @@
import torch.nn as nn


class ResidualDecomp(nn.Module):
    """Residual decomposition."""

    def __init__(self, input_shape):
        super().__init__()
        self.ln = nn.LayerNorm(input_shape[-1])
        self.ac = nn.ReLU()

    def forward(self, x, y):
        u = x - self.ac(y)
        u = self.ln(u)
        return u
model/D2STGNN/diffusion_block/dif_block.py
@@ -0,0 +1,56 @@
import torch.nn as nn

from model.D2STGNN.diffusion_block.forecast import Forecast
from model.D2STGNN.diffusion_block.dif_model import STLocalizedConv
from model.D2STGNN.decouple.residual_decomp import ResidualDecomp


class DifBlock(nn.Module):
    def __init__(self, hidden_dim, forecast_hidden_dim=256, dy_graph=None, **model_args):
        """Diffusion block.

        Args:
            hidden_dim (int): hidden dimension.
            forecast_hidden_dim (int, optional): forecast branch hidden dimension. Defaults to 256.
            dy_graph (bool, optional): whether to use the dynamic graph. Defaults to None.
        """

        super().__init__()
        # diffusion model - only the dynamic graph is kept
        self.localized_st_conv = STLocalizedConv(hidden_dim, dy_graph=dy_graph, **model_args)

        # forecast
        self.forecast_branch = Forecast(hidden_dim, forecast_hidden_dim=forecast_hidden_dim, **model_args)
        # backcast
        self.backcast_branch = nn.Linear(hidden_dim, hidden_dim)
        # residual decomposition
        self.residual_decompose = ResidualDecomp([-1, -1, -1, hidden_dim])

    def forward(self, history_data, gated_history_data, dynamic_graph, static_graph=None):
        """Diffusion block, containing the diffusion model, forecast branch, backcast branch, and the residual decomposition link.

        Args:
            history_data (torch.Tensor): history data with shape [batch_size, seq_len, num_nodes, hidden_dim]
            gated_history_data (torch.Tensor): gated history data with shape [batch_size, seq_len, num_nodes, hidden_dim]
            dynamic_graph (list): dynamic graphs.
            static_graph (list, optional): static graphs (unused).

        Returns:
            torch.Tensor: the output after the decoupling mechanism (backcast branch and the residual link), which should be fed to the inherent model.
                Shape: [batch_size, seq_len', num_nodes, hidden_dim]. Kindly note that after the st conv, the sequence will be shorter.
            torch.Tensor: the output of the forecast branch, which will be used to make the final prediction.
                Shape: [batch_size, seq_len'', num_nodes, forecast_hidden_dim]. seq_len'' = future_len / gap.
                In order to reduce the error accumulation in the AR forecasting strategy, we let each hidden state generate the prediction of gap points, instead of a single point.
        """

        # diffusion model - only the dynamic graph is used
        hidden_states_dif = self.localized_st_conv(gated_history_data, dynamic_graph, static_graph)
        # forecast branch: use the localized st conv to predict future hidden states
        forecast_hidden = self.forecast_branch(gated_history_data, hidden_states_dif, self.localized_st_conv, dynamic_graph, static_graph)
        # backcast branch: use an FC layer to do the backcast
        backcast_seq = self.backcast_branch(hidden_states_dif)
        # residual decomposition: remove the learned knowledge from the input data
        history_data = history_data[:, -backcast_seq.shape[1]:, :, :]
        backcast_seq_res = self.residual_decompose(history_data, backcast_seq)

        return backcast_seq_res, forecast_hidden
model/D2STGNN/diffusion_block/dif_model.py
@@ -0,0 +1,128 @@
import torch
import torch.nn as nn


class STLocalizedConv(nn.Module):
    def __init__(self, hidden_dim, dy_graph=None, **model_args):
        super().__init__()
        # gated temporal conv
        self.k_s = model_args['k_s']
        self.k_t = model_args['k_t']
        self.hidden_dim = hidden_dim

        # graph conv - only the dynamic graph is kept
        self.use_dynamic_hidden_graph = dy_graph

        # only dynamic graphs are considered
        self.support_len = int(dy_graph) if dy_graph is not None else 0

        # num_matric = 1 (X_0) + number of dynamic graphs
        self.num_matric = 1 + self.support_len
        self.dropout = nn.Dropout(model_args['dropout'])

        self.fc_list_updt = nn.Linear(
            self.k_t * hidden_dim, self.k_t * hidden_dim, bias=False)
        self.gcn_updt = nn.Linear(
            self.hidden_dim * self.num_matric, self.hidden_dim)

        # others
        self.bn = nn.BatchNorm2d(self.hidden_dim)
        self.activation = nn.ReLU()

    def gconv(self, support, X_k, X_0):
        out = [X_0]
        batch_size, seq_len, _, hidden_dim = X_0.shape

        for graph in support:
            # make sure the graph's shape matches X_k
            if len(graph.shape) == 3:  # dynamic graph with shape [B, N, K*N]
                # repeat the graph along the seq_len dimension
                graph = graph.unsqueeze(1).repeat(1, seq_len, 1, 1)  # [B, L, N, K*N]
            elif len(graph.shape) == 2:  # static graph with shape [N, K*N]
                graph = graph.unsqueeze(0).unsqueeze(1).repeat(batch_size, seq_len, 1, 1)  # [B, L, N, K*N]

            # make sure X_k has the right shape
            if X_k.dim() == 4:  # [B, L, K*N, D]
                # matrix product: [B, L, N, K*N] x [B, L, K*N, D] -> [B, L, N, D]
                H_k = torch.matmul(graph, X_k)
            else:
                H_k = torch.matmul(graph, X_k.unsqueeze(1))
                H_k = H_k.squeeze(1)

            out.append(H_k)

        # concatenate all results
        out = torch.cat(out, dim=-1)

        # dynamically adapt the input dimension of the linear layer
        if out.shape[-1] != self.gcn_updt.in_features:
            # create a new linear layer that matches the current input dimension
            new_gcn_updt = nn.Linear(out.shape[-1], self.hidden_dim).to(out.device)
            # copy over the existing parameters where possible
            with torch.no_grad():
                min_dim = min(out.shape[-1], self.gcn_updt.in_features)
                new_gcn_updt.weight[:, :min_dim] = self.gcn_updt.weight[:, :min_dim]
                if new_gcn_updt.bias is not None and self.gcn_updt.bias is not None:
                    new_gcn_updt.bias = self.gcn_updt.bias
            self.gcn_updt = new_gcn_updt

        out = self.gcn_updt(out)
        out = self.dropout(out)
        return out

    def get_graph(self, support):
        # Only used for static graphs (the static hidden graph and the predefined graph); dynamic graphs are localized in DynamicGraphConstructor instead.
        if support is None or len(support) == 0:
            return []

        graph_ordered = []
        mask = 1 - torch.eye(support[0].shape[0]).to(support[0].device)
        for graph in support:
            k_1_order = graph  # 1st order
            graph_ordered.append(k_1_order * mask)
            # e.g., order = 3, k=[2, 3]; order = 2, k=[2]
            for k in range(2, self.k_s + 1):
                k_1_order = torch.matmul(graph, k_1_order)
                graph_ordered.append(k_1_order * mask)
        # get the st-localized graph
        st_local_graph = []
        for graph in graph_ordered:
            graph = graph.unsqueeze(-2).expand(-1, self.k_t, -1)
            graph = graph.reshape(
                graph.shape[0], graph.shape[1] * graph.shape[2])
            # [num_nodes, kernel_size x num_nodes]
            st_local_graph.append(graph)
        # [order, num_nodes, kernel_size x num_nodes]
        return st_local_graph

    def forward(self, X, dynamic_graph, static_graph=None):
        # X: [bs, seq, nodes, feat]
        # [bs, seq, num_nodes, ks, num_feat]
        X = X.unfold(1, self.k_t, 1).permute(0, 1, 2, 4, 3)
        # seq_len is changing
        batch_size, seq_len, num_nodes, kernel_size, num_feat = X.shape

        # support - only the dynamic graph is kept
        support = []
        if self.use_dynamic_hidden_graph and dynamic_graph:
            # the k-order graphs are calculated in the dynamic_graph_constructor component
            support = support + dynamic_graph

        # parallelize
        X = X.reshape(batch_size, seq_len, num_nodes, kernel_size * num_feat)
        # batch_size, seq_len, num_nodes, kernel_size * hidden_dim
        out = self.fc_list_updt(X)
        out = self.activation(out)
        out = out.view(batch_size, seq_len, num_nodes, kernel_size, num_feat)
        X_0 = torch.mean(out, dim=-2)
        # batch_size, seq_len, kernel_size x num_nodes, hidden_dim
        X_k = out.transpose(-3, -2).reshape(batch_size,
                                            seq_len, kernel_size * num_nodes, num_feat)

        # if support is empty, return X_0 directly
        if len(support) == 0:
            return X_0

        # Nx3N 3NxD -> NxD: batch_size, seq_len, num_nodes, hidden_dim
        hidden = self.gconv(support, X_k, X_0)
        return hidden
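To make the "seq_len is changing" comment concrete, a standalone check of the temporal unfolding used above: unfold(1, k_t, 1) shortens the sequence to L - k_t + 1 windows and appends the window as a new axis, which the permute moves next to the features.

import torch

X = torch.randn(2, 24, 35, 32)               # [batch, seq_len, num_nodes, hidden_dim]
k_t = 3
Xu = X.unfold(1, k_t, 1).permute(0, 1, 2, 4, 3)
print(Xu.shape)                               # torch.Size([2, 22, 35, 3, 32]) -> [bs, seq', nodes, k_t, feat]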
model/D2STGNN/diffusion_block/forecast.py
@@ -0,0 +1,27 @@
import torch
import torch.nn as nn


class Forecast(nn.Module):
    def __init__(self, hidden_dim, forecast_hidden_dim=None, **model_args):
        super().__init__()
        self.k_t = model_args['k_t']
        self.output_seq_len = model_args['horizon']  # use horizon as the target sequence length
        self.forecast_fc = nn.Linear(hidden_dim, forecast_hidden_dim)
        self.model_args = model_args

    def forward(self, gated_history_data, hidden_states_dif, localized_st_conv, dynamic_graph, static_graph):
        predict = []
        history = gated_history_data
        predict.append(hidden_states_dif[:, -1, :, :].unsqueeze(1))
        for _ in range(int(self.output_seq_len / self.model_args['gap']) - 1):
            _1 = predict[-self.k_t:]
            if len(_1) < self.k_t:
                sub = self.k_t - len(_1)
                _2 = history[:, -sub:, :, :]
                _1 = torch.cat([_2] + _1, dim=1)
            else:
                _1 = torch.cat(_1, dim=1)
            predict.append(localized_st_conv(_1, dynamic_graph, static_graph))
        predict = torch.cat(predict, dim=1)
        predict = self.forecast_fc(predict)
        return predict
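A quick arithmetic check of the gap-based decoding loop above, using the values from the configs (horizon 24, gap 3): the branch emits one hidden state for the last input step plus horizon/gap - 1 autoregressive steps, and each state covers gap future points.

horizon, gap = 24, 3
n_steps = 1 + (horizon // gap - 1)  # states produced by the loop above
print(n_steps)                      # 8, i.e. horizon / gap hidden states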
model/D2STGNN/dynamic_graph_conv/dy_graph_conv.py
@@ -0,0 +1,66 @@
import torch.nn as nn

from model.D2STGNN.dynamic_graph_conv.utils.distance import DistanceFunction
from model.D2STGNN.dynamic_graph_conv.utils.mask import Mask

from model.D2STGNN.dynamic_graph_conv.utils.normalizer import Normalizer, MultiOrder


class DynamicGraphConstructor(nn.Module):
    def __init__(self, **model_args):
        super().__init__()
        # model args
        self.k_s = model_args['k_s']  # spatial order
        self.k_t = model_args['k_t']  # temporal kernel size
        # hidden dimension
        self.hidden_dim = model_args['num_hidden']
        # trainable node embedding dimension
        self.node_dim = model_args['node_hidden']

        self.distance_function = DistanceFunction(**model_args)
        self.mask = Mask(**model_args)
        self.normalizer = Normalizer()
        self.multi_order = MultiOrder(order=self.k_s)

    def st_localization(self, graph_ordered):
        st_local_graph = []
        for modality_i in graph_ordered:
            for k_order_graph in modality_i:
                k_order_graph = k_order_graph.unsqueeze(
                    -2).expand(-1, -1, self.k_t, -1)
                k_order_graph = k_order_graph.reshape(
                    k_order_graph.shape[0], k_order_graph.shape[1], k_order_graph.shape[2] * k_order_graph.shape[3])
                st_local_graph.append(k_order_graph)
        return st_local_graph

    def forward(self, **inputs):
        """Dynamic graph learning module.

        Args:
            history_data (torch.Tensor): input data with shape (B, L, N, D)
            node_embedding_u (torch.Parameter): node embedding E_u
            node_embedding_d (torch.Parameter): node embedding E_d
            time_in_day_feat (torch.Parameter): time embedding T_D
            day_in_week_feat (torch.Parameter): time embedding T_W

        Returns:
            list: dynamic graphs
        """

        X = inputs['history_data']
        E_d = inputs['node_d']  # the keyword argument is now named node_d
        E_u = inputs['node_u']  # the keyword argument is now named node_u
        T_D = inputs['time_in_day_feat']
        D_W = inputs['day_in_week_feat']
        # distance calculation
        dist_mx = self.distance_function(X, E_d, E_u, T_D, D_W)
        # mask
        dist_mx = self.mask(dist_mx)
        # normalization
        dist_mx = self.normalizer(dist_mx)
        # multi order
        mul_mx = self.multi_order(dist_mx)
        # spatial-temporal localization
        dynamic_graphs = self.st_localization(mul_mx)

        return dynamic_graphs
model/D2STGNN/dynamic_graph_conv/utils/distance.py
@@ -0,0 +1,59 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistanceFunction(nn.Module):
    def __init__(self, **model_args):
        super().__init__()
        # attributes
        self.hidden_dim = model_args['num_hidden']
        self.node_dim = model_args['node_hidden']
        self.time_slot_emb_dim = self.hidden_dim
        self.input_seq_len = model_args['seq_len']
        # time series feature extraction
        self.dropout = nn.Dropout(model_args['dropout'])
        self.fc_ts_emb1 = nn.Linear(self.input_seq_len, self.hidden_dim * 2)
        self.fc_ts_emb2 = nn.Linear(self.hidden_dim * 2, self.hidden_dim)
        self.ts_feat_dim = self.hidden_dim
        # time slot embedding extraction
        self.time_slot_embedding = nn.Linear(model_args['time_emb_dim'], self.time_slot_emb_dim)
        # distance score
        self.all_feat_dim = self.ts_feat_dim + self.node_dim + model_args['time_emb_dim'] * 2
        self.WQ = nn.Linear(self.all_feat_dim, self.hidden_dim, bias=False)
        self.WK = nn.Linear(self.all_feat_dim, self.hidden_dim, bias=False)
        self.bn = nn.BatchNorm1d(self.hidden_dim * 2)

    def reset_parameters(self):
        # initialize the parameters of all linear layers
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, X, E_d, E_u, T_D, D_W):
        # last pooling
        T_D = T_D[:, -1, :, :]
        D_W = D_W[:, -1, :, :]
        # dynamic information
        X = X[:, :, :, 0].transpose(1, 2).contiguous()  # X -> [batch_size, seq_len, num_nodes] -> [batch_size, num_nodes, seq_len]
        [batch_size, num_nodes, seq_len] = X.shape
        X = X.view(batch_size * num_nodes, seq_len)
        dy_feat = self.fc_ts_emb2(self.dropout(self.bn(F.relu(self.fc_ts_emb1(X)))))  # [batch_size * num_nodes, hidden_dim]
        dy_feat = dy_feat.view(batch_size, num_nodes, -1)
        # node embedding
        emb1 = E_d.unsqueeze(0).expand(batch_size, -1, -1)
        emb2 = E_u.unsqueeze(0).expand(batch_size, -1, -1)
        # distance calculation
        X1 = torch.cat([dy_feat, T_D, D_W, emb1], dim=-1)  # hidden state for calculating distance
        X2 = torch.cat([dy_feat, T_D, D_W, emb2], dim=-1)  # hidden state for calculating distance
        X = [X1, X2]
        adjacent_list = []
        for _ in X:
            Q = self.WQ(_)
            K = self.WK(_)
            QKT = torch.bmm(Q, K.transpose(-1, -2)) / math.sqrt(self.hidden_dim)
            W = torch.softmax(QKT, dim=-1)
            adjacent_list.append(W)
        return adjacent_list
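A standalone shape check for the attention-style distance above (sizes from the AirQuality config): it returns two dense [batch, N, N] adjacencies, one per node-embedding direction.

import torch
from model.D2STGNN.dynamic_graph_conv.utils.distance import DistanceFunction

args = {'num_hidden': 32, 'node_hidden': 10, 'seq_len': 24, 'dropout': 0.1, 'time_emb_dim': 10}
dist = DistanceFunction(**args)
B, L, N = 4, 24, 35
adjs = dist(
    torch.randn(B, L, N, 6),                              # raw history (only channel 0 is used)
    torch.randn(N, 10), torch.randn(N, 10),               # E_d, E_u
    torch.randn(B, L, N, 10), torch.randn(B, L, N, 10),   # T_D, D_W (only the last step is kept)
)
print(len(adjs), adjs[0].shape)                           # 2 torch.Size([4, 35, 35])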
model/D2STGNN/dynamic_graph_conv/utils/mask.py
@@ -0,0 +1,21 @@
import torch
import torch.nn as nn


class Mask(nn.Module):
    def __init__(self, **model_args):
        super().__init__()
        self.mask = model_args.get('adjs', None)  # adjs is allowed to be None

    def _mask(self, index, adj):
        if self.mask is None or len(self.mask) == 0:
            # without predefined adjacency matrices, return the original adj unchanged
            return adj
        else:
            mask = self.mask[index] + torch.ones_like(self.mask[index]) * 1e-7
            return mask.to(adj.device) * adj

    def forward(self, adj):
        result = []
        for index, _ in enumerate(adj):
            result.append(self._mask(index, _))
        return result
model/D2STGNN/dynamic_graph_conv/utils/normalizer.py
@@ -0,0 +1,42 @@
import torch
import torch.nn as nn


def remove_nan_inf(x):
    """Replace NaN and Inf entries of a tensor with zeros."""
    x = torch.where(torch.isnan(x) | torch.isinf(x), torch.zeros_like(x), x)
    return x


class Normalizer(nn.Module):
    def __init__(self):
        super().__init__()

    def _norm(self, graph):
        degree = torch.sum(graph, dim=2)
        degree = remove_nan_inf(1 / degree)
        degree = torch.diag_embed(degree)
        normed_graph = torch.bmm(degree, graph)
        return normed_graph

    def forward(self, adj):
        return [self._norm(_) for _ in adj]


class MultiOrder(nn.Module):
    def __init__(self, order=2):
        super().__init__()
        self.order = order

    def _multi_order(self, graph):
        graph_ordered = []
        k_1_order = graph  # 1st order
        mask = torch.eye(graph.shape[1]).to(graph.device)
        mask = 1 - mask
        graph_ordered.append(k_1_order * mask)
        for k in range(2, self.order + 1):  # e.g., order = 3, k=[2, 3]; order = 2, k=[2]
            k_1_order = torch.matmul(k_1_order, graph)
            graph_ordered.append(k_1_order * mask)
        return graph_ordered

    def forward(self, adj):
        return [self._multi_order(_) for _ in adj]
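A small illustration (with made-up sizes) of how DynamicGraphConstructor chains these utilities: row-normalize a batch of dense adjacencies, then take k-hop powers with the diagonal masked out.

import torch
from model.D2STGNN.dynamic_graph_conv.utils.normalizer import Normalizer, MultiOrder

adj = [torch.rand(8, 35, 35)]          # one dynamic graph per list entry: [batch, N, N]
normed = Normalizer()(adj)             # row-normalized graphs
ordered = MultiOrder(order=2)(normed)  # per graph: [1st-order, 2nd-order], diagonal removed
print(len(ordered), len(ordered[0]), ordered[0][0].shape)  # 1 2 torch.Size([8, 35, 35])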
model/D2STGNN/inherent_block/forecast.py
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn


class Forecast(nn.Module):
    def __init__(self, hidden_dim, fk_dim, **model_args):
        super().__init__()
        self.output_seq_len = model_args['seq_len']
        self.model_args = model_args

        self.forecast_fc = nn.Linear(hidden_dim, fk_dim)

    def forward(self, X, RNN_H, Z, transformer_layer, rnn_layer, pe):
        [batch_size, _, num_nodes, num_feat] = X.shape

        predict = [Z[-1, :, :].unsqueeze(0)]
        for _ in range(int(self.output_seq_len / self.model_args['gap']) - 1):
            # RNN
            _gru = rnn_layer.gru_cell(predict[-1][0], RNN_H[-1]).unsqueeze(0)
            RNN_H = torch.cat([RNN_H, _gru], dim=0)
            # positional encoding
            if pe is not None:
                RNN_H = pe(RNN_H)
            # Transformer
            _Z = transformer_layer(_gru, K=RNN_H, V=RNN_H)
            predict.append(_Z)

        predict = torch.cat(predict, dim=0)
        predict = predict.reshape(-1, batch_size, num_nodes, num_feat)
        predict = predict.transpose(0, 1)
        predict = self.forecast_fc(predict)
        return predict
model/D2STGNN/inherent_block/inh_block.py
@@ -0,0 +1,86 @@
import math

import torch
import torch.nn as nn

from model.D2STGNN.decouple.residual_decomp import ResidualDecomp
from model.D2STGNN.inherent_block.inh_model import RNNLayer, TransformerLayer
from model.D2STGNN.inherent_block.forecast import Forecast


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=None, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, X):
        X = X + self.pe[:X.size(0)]
        X = self.dropout(X)
        return X


class InhBlock(nn.Module):
    def __init__(self, hidden_dim, num_heads=4, bias=True, forecast_hidden_dim=256, **model_args):
        """Inherent block.

        Args:
            hidden_dim (int): hidden dimension
            num_heads (int, optional): number of heads of MSA. Defaults to 4.
            bias (bool, optional): whether to use bias. Defaults to True.
            forecast_hidden_dim (int, optional): forecast branch hidden dimension. Defaults to 256.
        """
        super().__init__()
        self.num_feat = hidden_dim
        self.hidden_dim = hidden_dim

        # inherent model
        self.pos_encoder = PositionalEncoding(hidden_dim, model_args['dropout'])
        self.rnn_layer = RNNLayer(hidden_dim, model_args['dropout'])
        self.transformer_layer = TransformerLayer(hidden_dim, num_heads, model_args['dropout'], bias)

        # forecast branch
        self.forecast_block = Forecast(hidden_dim, forecast_hidden_dim, **model_args)
        # backcast branch
        self.backcast_fc = nn.Linear(hidden_dim, hidden_dim)
        # residual decomposition
        self.residual_decompose = ResidualDecomp([-1, -1, -1, hidden_dim])

    def forward(self, hidden_inherent_signal):
        """Inherent block, containing the inherent model, forecast branch, backcast branch, and the residual decomposition link.

        Args:
            hidden_inherent_signal (torch.Tensor): hidden inherent signal with shape [batch_size, seq_len, num_nodes, num_feat].

        Returns:
            torch.Tensor: the output after the decoupling mechanism (backcast branch and the residual link), which should be fed to the next decouple layer.
                Shape: [batch_size, seq_len, num_nodes, hidden_dim].
            torch.Tensor: the output of the forecast branch, which will be used to make the final prediction.
                Shape: [batch_size, seq_len'', num_nodes, forecast_hidden_dim]. seq_len'' = future_len / gap.
                In order to reduce the error accumulation in the AR forecasting strategy, we let each hidden state generate the prediction of gap points, instead of a single point.
        """

        [batch_size, seq_len, num_nodes, num_feat] = hidden_inherent_signal.shape
        # inherent model
        ## rnn
        hidden_states_rnn = self.rnn_layer(hidden_inherent_signal)
        ## positional encoding
        hidden_states_rnn = self.pos_encoder(hidden_states_rnn)
        ## MSA
        hidden_states_inh = self.transformer_layer(hidden_states_rnn, hidden_states_rnn, hidden_states_rnn)

        # forecast branch
        forecast_hidden = self.forecast_block(hidden_inherent_signal, hidden_states_rnn, hidden_states_inh, self.transformer_layer, self.rnn_layer, self.pos_encoder)

        # backcast branch
        hidden_states_inh = hidden_states_inh.reshape(seq_len, batch_size, num_nodes, num_feat)
        hidden_states_inh = hidden_states_inh.transpose(0, 1)
        backcast_seq = self.backcast_fc(hidden_states_inh)
        backcast_seq_res = self.residual_decompose(hidden_inherent_signal, backcast_seq)

        return backcast_seq_res, forecast_hidden
model/D2STGNN/inherent_block/inh_model.py
@@ -0,0 +1,35 @@
import torch as th
import torch.nn as nn
from torch.nn import MultiheadAttention


class RNNLayer(nn.Module):
    def __init__(self, hidden_dim, dropout=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.gru_cell = nn.GRUCell(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, X):
        [batch_size, seq_len, num_nodes, hidden_dim] = X.shape
        X = X.transpose(1, 2).reshape(batch_size * num_nodes, seq_len, hidden_dim)
        hx = th.zeros_like(X[:, 0, :])
        output = []
        for _ in range(X.shape[1]):
            hx = self.gru_cell(X[:, _, :], hx)
            output.append(hx)
        output = th.stack(output, dim=0)
        output = self.dropout(output)
        return output


class TransformerLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads=4, dropout=None, bias=True):
        super().__init__()
        self.multi_head_self_attention = MultiheadAttention(hidden_dim, num_heads, dropout=dropout, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, X, K, V):
        hidden_states_MSA = self.multi_head_self_attention(X, K, V)[0]
        hidden_states_MSA = self.dropout(hidden_states_MSA)
        return hidden_states_MSA
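A layout note with a standalone sketch: RNNLayer folds the node axis into the batch and returns a sequence-first tensor, which matches the default batch_first=False layout nn.MultiheadAttention expects inside TransformerLayer.

import torch
from model.D2STGNN.inherent_block.inh_model import RNNLayer, TransformerLayer

B, L, N, D = 2, 24, 35, 32
rnn = RNNLayer(D, dropout=0.1)
attn = TransformerLayer(D, num_heads=4, dropout=0.1)
h = rnn(torch.randn(B, L, N, D))   # [L, B * N, D], sequence-first
z = attn(h, h, h)                  # multi-head self-attention over the L axis
print(h.shape, z.shape)            # torch.Size([24, 70, 32]) torch.Size([24, 70, 32])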
@@ -0,0 +1,7 @@
[
    {
        "name": "D2STGNN",
        "module": "model.D2STGNN.D2STGNN",
        "entry": "D2STGNN"
    }
]
train.py
@@ -6,14 +6,15 @@ import utils.initializer as init
 from dataloader.loader_selector import get_dataloader
 from trainer.trainer_selector import select_trainer


 def read_config(config_path):
     with open(config_path, "r") as file:
         config = yaml.safe_load(file)

     # global configuration
-    device = "cpu"  # target device, e.g. cuda:0
-    seed = 2023  # random seed
-    epochs = 1  # number of training epochs
+    device = "cpu"  # target device, e.g. cuda:0
+    seed = 2023  # random seed
+    epochs = 1  # number of training epochs

     # copied into the loaded config
     config["basic"]["device"] = device
@@ -23,6 +24,7 @@ def read_config(config_path):
     config["train"]["epochs"] = epochs
     return config

+
 def run(config):
     init.init_seed(config["basic"]["seed"])
     model = init.init_model(config)
@@ -34,10 +36,15 @@ def run(config):
     init.create_logs(config)
     trainer = select_trainer(
         model,
-        loss, optimizer,
-        train_loader, val_loader, test_loader, scaler,
+        loss,
+        optimizer,
+        train_loader,
+        val_loader,
+        test_loader,
+        scaler,
         config,
-        lr_scheduler, extra_data,
+        lr_scheduler,
+        extra_data,
     )

     # start training
@@ -54,17 +61,20 @@ def run(config):
             )
             trainer.test(
                 model.to(config["basic"]["device"]),
-                trainer.args, test_loader, scaler,
+                trainer.args,
+                test_loader,
+                scaler,
                 trainer.logger,
             )
         case _:
             raise ValueError(f"Unsupported mode: {config['basic']['mode']}")

-def main(model, data, debug=False):
+
+def main(model_list, data, debug=False):
     # my debug switch; set str(False) when not testing
     # os.environ["TRY"] = str(False)
     os.environ["TRY"] = str(debug)

     for model in model_list:
         for dataset in data:
             config_path = f"./config/{model}/{dataset}.yaml"
@@ -77,22 +87,25 @@ def main(model, data, debug=False):
             except Exception as e:
-                import traceback
+                import sys, traceback

                 tb_lines = traceback.format_exc().splitlines()
                 # print the full traceback only if it is not an AssertionError
                 if not tb_lines[-1].startswith("AssertionError"):
                     traceback.print_exc()
-                print(f"\n===== {model} on {dataset} failed with error: {e} =====\n")
+                print(
+                    f"\n===== {model} on {dataset} failed with error: {e} =====\n"
+                )
             else:
                 run(config)


 if __name__ == "__main__":
     # for debugging
     # model_list = ["iTransformer", "PatchTST", "HI"]
-    model_list = ["STNorm"]
+    model_list = ["D2STGNN"]
     # model_list = ["PatchTST"]
     # dataset_list = ["AirQuality"]
-    dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]
+    # dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]
     # dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
-    main(model_list, dataset_list, debug = True)
+    dataset_list = ["BJTaxi-OutFlow"]
+    main(model_list, dataset_list, debug=True)