impl D2STGNN

This commit is contained in:
czzhangheng 2025-12-20 18:02:01 +08:00
parent 9d3293cef7
commit b46c16815e
23 changed files with 1194 additions and 16 deletions

View File

@ -0,0 +1,60 @@
basic:
dataset: AirQuality
device: cuda:0
mode: train
model: D2STGNN
seed: 2023
data:
batch_size: 64
column_wise: false
days_per_week: 7
horizon: 24
input_dim: 6
lag: 24
normalizer: std
num_nodes: 35
steps_per_day: 24
test_ratio: 0.2
val_ratio: 0.2
model:
num_nodes: 35
num_layers: 4
num_hidden: 32
forecast_dim: 256
output_hidden: 512
output_dim: 6
seq_len: 24
horizon: 24
input_dim: 6
num_timesteps_in_day: 24
time_emb_dim: 10
node_hidden: 10
dy_graph: True
sta_graph: False
gap: 3
k_s: 2
k_t: 3
dropout: 0.1
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_step: 1000
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 6
plot: false
real_value: true
weight_decay: 0
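For reference, a minimal sketch (not part of the commit) of how a config like the one above is consumed: read_config in the runner script loads it with yaml.safe_load and then overrides a few keys, and the path follows the ./config/{model}/{dataset}.yaml convention used there.

import yaml

with open("./config/D2STGNN/AirQuality.yaml", "r") as f:
    config = yaml.safe_load(f)

print(config["basic"]["model"])                                   # D2STGNN
print(config["model"]["num_nodes"], config["train"]["lr_init"])   # 35 0.003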

View File

@ -0,0 +1,60 @@
basic:
dataset: BJTaxi-InFlow
device: cuda:0
mode: train
model: D2STGNN
seed: 2023
data:
batch_size: 16
column_wise: false
days_per_week: 7
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 1024
steps_per_day: 48
test_ratio: 0.2
val_ratio: 0.2
model:
num_nodes: 1024
num_layers: 4
num_hidden: 32
forecast_dim: 256
output_hidden: 512
output_dim: 1
seq_len: 24
horizon: 24
input_dim: 1
num_timesteps_in_day: 48
time_emb_dim: 10
node_hidden: 10
dy_graph: True
sta_graph: False
gap: 3
k_s: 2
k_t: 3
dropout: 0.1
train:
batch_size: 16
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_step: 1000
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
plot: false
real_value: true
weight_decay: 0

View File

@ -0,0 +1,60 @@
basic:
dataset: BJTaxi-OutFlow
device: cuda:0
mode: train
model: D2STGNN
seed: 2023
data:
batch_size: 16
column_wise: false
days_per_week: 7
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 1024
steps_per_day: 48
test_ratio: 0.2
val_ratio: 0.2
model:
num_nodes: 1024
num_layers: 4
num_hidden: 32
forecast_dim: 256
output_hidden: 512
output_dim: 1
seq_len: 24
horizon: 24
input_dim: 1
num_timesteps_in_day: 48
time_emb_dim: 10
node_hidden: 10
dy_graph: True
sta_graph: False
gap: 3
k_s: 2
k_t: 3
dropout: 0.1
train:
batch_size: 16
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_step: 1000
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
plot: false
real_value: true
weight_decay: 0

View File

@ -0,0 +1,60 @@
basic:
dataset: METR-LA
device: cuda:0
mode: train
model: D2STGNN
seed: 2023
data:
batch_size: 16
column_wise: false
days_per_week: 7
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 207
steps_per_day: 288
test_ratio: 0.2
val_ratio: 0.2
model:
num_nodes: 207
num_layers: 4
num_hidden: 32
forecast_dim: 256
output_hidden: 512
output_dim: 1
seq_len: 24
horizon: 24
input_dim: 1
num_timesteps_in_day: 288
time_emb_dim: 10
node_hidden: 10
dy_graph: True
sta_graph: False
gap: 3
k_s: 2
k_t: 3
dropout: 0.1
train:
batch_size: 16
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_step: 1000
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
plot: false
real_value: true
weight_decay: 0

View File

@ -0,0 +1,60 @@
basic:
dataset: NYCBike-InFlow
device: cuda:0
mode: train
model: D2STGNN
seed: 2023
data:
batch_size: 64
column_wise: false
days_per_week: 7
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 128
steps_per_day: 48
test_ratio: 0.2
val_ratio: 0.2
model:
num_nodes: 128
num_layers: 4
num_hidden: 32
forecast_dim: 256
output_hidden: 512
output_dim: 1
seq_len: 24
horizon: 24
input_dim: 1
num_timesteps_in_day: 48
time_emb_dim: 10
node_hidden: 10
dy_graph: True
sta_graph: False
gap: 3
k_s: 2
k_t: 3
dropout: 0.1
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_step: 1000
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
plot: false
real_value: true
weight_decay: 0

View File

@ -0,0 +1,60 @@
basic:
dataset: NYCBike-OutFlow
device: cuda:0
mode: train
model: D2STGNN
seed: 2023
data:
batch_size: 16
column_wise: false
days_per_week: 7
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 128
steps_per_day: 48
test_ratio: 0.2
val_ratio: 0.2
model:
num_nodes: 128
num_layers: 4
num_hidden: 32
forecast_dim: 256
output_hidden: 512
output_dim: 1
seq_len: 24
horizon: 24
input_dim: 1
num_timesteps_in_day: 48
time_emb_dim: 10
node_hidden: 10
dy_graph: True
sta_graph: False
gap: 3
k_s: 2
k_t: 3
dropout: 0.1
train:
batch_size: 16
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_step: 1000
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
plot: false
real_value: true
weight_decay: 0

View File

@ -0,0 +1,60 @@
basic:
dataset: PEMS-BAY
device: cuda:0
mode: train
model: D2STGNN
seed: 2023
data:
batch_size: 16
column_wise: false
days_per_week: 7
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 325
steps_per_day: 288
test_ratio: 0.2
val_ratio: 0.2
model:
num_nodes: 325
num_layers: 4
num_hidden: 32
forecast_dim: 256
output_hidden: 512
output_dim: 1
seq_len: 24
horizon: 24
input_dim: 1
num_timesteps_in_day: 288
time_emb_dim: 10
node_hidden: 10
dy_graph: True
sta_graph: False
gap: 3
k_s: 2
k_t: 3
dropout: 0.1
train:
batch_size: 16
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_step: 1000
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
plot: false
real_value: true
weight_decay: 0

View File

@ -0,0 +1,60 @@
basic:
dataset: SolarEnergy
device: cuda:0
mode: train
model: D2STGNN
seed: 2023
data:
batch_size: 64
column_wise: false
days_per_week: 7
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 137
steps_per_day: 24
test_ratio: 0.2
val_ratio: 0.2
model:
num_nodes: 137
num_layers: 4
num_hidden: 32
forecast_dim: 256
output_hidden: 512
output_dim: 1
seq_len: 24
horizon: 24
input_dim: 1
num_timesteps_in_day: 24
time_emb_dim: 10
node_hidden: 10
dy_graph: True
sta_graph: False
gap: 3
k_s: 2
k_t: 3
dropout: 0.1
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_step: 1000
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
plot: false
real_value: true
weight_decay: 0

model/D2STGNN/D2STGNN.py (new file, 88 lines)
View File

@ -0,0 +1,88 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.D2STGNN.diffusion_block.dif_block import DifBlock
from model.D2STGNN.inherent_block.inh_block import InhBlock
from model.D2STGNN.dynamic_graph_conv.dy_graph_conv import DynamicGraphConstructor
from model.D2STGNN.decouple.estimation_gate import EstimationGate
class DecoupleLayer(nn.Module):
def __init__(self, hidden_dim, fk_dim, args):
super().__init__()
self.est_gate = EstimationGate(node_emb_dim=args['node_hidden'], time_emb_dim=args['time_emb_dim'], hidden_dim=64)
# only the required arguments are passed explicitly; dy_graph is forwarded via **args
self.dif_layer = DifBlock(hidden_dim, forecast_hidden_dim=fk_dim, **args)
self.inh_layer = InhBlock(hidden_dim, forecast_hidden_dim=fk_dim, **args)
def forward(self, x, dyn_graph, sta_graph=None, node_u=None, node_d=None, t_in_day=None, t_in_week=None):
gated_x = self.est_gate(node_u, node_d, t_in_day, t_in_week, x)
dif_back, dif_hidden = self.dif_layer(x, gated_x, dyn_graph, sta_graph)
inh_back, inh_hidden = self.inh_layer(dif_back)
return inh_back, dif_hidden, inh_hidden
class D2STGNN(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args # keep args for use in forward
self.num_nodes = args['num_nodes']
self.num_layers = args['num_layers']
self.hidden_dim = args['num_hidden']
self.forecast_dim = args['forecast_dim']
self.output_hidden = args['output_hidden']
self.output_dim = args['output_dim']
self.in_feat = args['input_dim']
self.embedding = nn.Linear(self.in_feat, self.hidden_dim)
self.T_i_D_emb = nn.Parameter(torch.empty(args.get('num_timesteps_in_day',288), args['time_emb_dim']))
self.D_i_W_emb = nn.Parameter(torch.empty(7, args['time_emb_dim']))
self.node_u = nn.Parameter(torch.empty(self.num_nodes, args['node_hidden']))
self.node_d = nn.Parameter(torch.empty(self.num_nodes, args['node_hidden']))
self.layers = nn.ModuleList([DecoupleLayer(self.hidden_dim, self.forecast_dim, args) for _ in range(self.num_layers)])
if args.get('dy_graph', False):
self.dynamic_graph_constructor = DynamicGraphConstructor(**args)
self.out_fc1 = nn.Linear(self.forecast_dim, self.output_hidden)
self.out_fc2 = nn.Linear(self.output_hidden, args['gap'] * args['output_dim'])
self._reset_parameters()
def _reset_parameters(self):
for p in [self.node_u, self.node_d, self.T_i_D_emb, self.D_i_W_emb]:
nn.init.xavier_uniform_(p)
def _prepare_inputs(self, x):
node_u, node_d = self.node_u, self.node_d
t_in_day = self.T_i_D_emb[(x[:, :, :, -2]*self.T_i_D_emb.size(0)).long()]
t_in_week = self.D_i_W_emb[x[:, :, :, -1].long()]
return x[:, :, :, :-2], node_u, node_d, t_in_day, t_in_week
def _graph_constructor(self, node_u, node_d, x, t_in_day, t_in_week):
# only construct the dynamic graph; the static graph has been removed
dyn_graph = self.dynamic_graph_constructor(node_u=node_u, node_d=node_d, history_data=x, time_in_day_feat=t_in_day, day_in_week_feat=t_in_week) if hasattr(self, 'dynamic_graph_constructor') else []
return [], dyn_graph
def forward(self, x):
x, node_u, node_d, t_in_day, t_in_week = self._prepare_inputs(x)
sta_graph, dyn_graph = self._graph_constructor(node_u, node_d, x, t_in_day, t_in_week)
x = self.embedding(x)
dif_hidden_list, inh_hidden_list = [], []
backcast = x
for layer in self.layers:
backcast, dif_hidden, inh_hidden = layer(backcast, dyn_graph, sta_graph, node_u, node_d, t_in_day, t_in_week)
dif_hidden_list.append(dif_hidden)
inh_hidden_list.append(inh_hidden)
forecast_hidden = sum(dif_hidden_list) + sum(inh_hidden_list)
# adjust the output shape so that it matches the labels
forecast = self.out_fc1(F.relu(forecast_hidden))
forecast = F.relu(forecast)
forecast = self.out_fc2(forecast)
# make sure the feature dimension is correct
if forecast.size(-1) != self.args['output_dim']:
forecast = forecast[..., :self.args['output_dim']]
# make sure the number of output time steps is correct
if forecast.size(1) != self.args['horizon']:
# if there are too few time steps, tile them along the time axis
forecast = forecast.repeat(1, self.args['horizon'] // forecast.size(1) + 1, 1, 1)[:, :self.args['horizon'], :, :]
return forecast
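A quick smoke-test sketch for the model above, using the hyper-parameters from the AirQuality config; the tensor sizes are illustrative and the snippet is not part of the commit. The last two input channels carry the time-of-day fraction and the day-of-week index, as read in _prepare_inputs.

import torch
from model.D2STGNN.D2STGNN import D2STGNN

args = {
    "num_nodes": 35, "num_layers": 4, "num_hidden": 32, "forecast_dim": 256,
    "output_hidden": 512, "output_dim": 6, "seq_len": 24, "horizon": 24,
    "input_dim": 6, "num_timesteps_in_day": 24, "time_emb_dim": 10,
    "node_hidden": 10, "dy_graph": True, "sta_graph": False,
    "gap": 3, "k_s": 2, "k_t": 3, "dropout": 0.1,
}
model = D2STGNN(args)
x = torch.rand(2, 24, 35, 6 + 2)                       # [batch, lag, nodes, input_dim + 2 time features]
x[..., -1] = torch.randint(0, 7, (2, 24, 35)).float()  # day-of-week index in {0, ..., 6}
y = model(x)                                           # expected shape: [2, 24, 35, 6] = [batch, horizon, nodes, output_dim]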

model/D2STGNN/decouple/estimation_gate.py
View File

@ -0,0 +1,24 @@
import torch
import torch.nn as nn
class EstimationGate(nn.Module):
"""The estimation gate module."""
def __init__(self, node_emb_dim, time_emb_dim, hidden_dim):
super().__init__()
self.fully_connected_layer_1 = nn.Linear(2 * node_emb_dim + time_emb_dim * 2, hidden_dim)
self.activation = nn.ReLU()
self.fully_connected_layer_2 = nn.Linear(hidden_dim, 1)
def forward(self, node_embedding_u, node_embedding_d, time_in_day_feat, day_in_week_feat, history_data):
"""Generate gate value in (0, 1) based on current node and time step embeddings to roughly estimating the proportion of the two hidden time series."""
batch_size, seq_length, _, _ = time_in_day_feat.shape
estimation_gate_feat = torch.cat([time_in_day_feat, day_in_week_feat, node_embedding_u.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_length, -1, -1), node_embedding_d.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_length, -1, -1)], dim=-1)
hidden = self.fully_connected_layer_1(estimation_gate_feat)
hidden = self.activation(hidden)
# activation
estimation_gate = torch.sigmoid(self.fully_connected_layer_2(hidden))[:, -history_data.shape[1]:, :, :]
history_data = history_data * estimation_gate
return history_data
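A shape-check sketch for the gate above (arbitrary sizes, assuming the class is importable): the two node embeddings and the two time embeddings are concatenated per node and time step, squashed to a scalar gate in (0, 1), and multiplied onto the hidden history.

import torch

gate = EstimationGate(node_emb_dim=10, time_emb_dim=10, hidden_dim=64)
B, L, N, D = 4, 24, 35, 32
node_u, node_d = torch.randn(N, 10), torch.randn(N, 10)
t_in_day, t_in_week = torch.randn(B, L, N, 10), torch.randn(B, L, N, 10)
x = torch.randn(B, L, N, D)
gated = gate(node_u, node_d, t_in_day, t_in_week, x)   # same shape as x: [4, 24, 35, 32]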

model/D2STGNN/decouple/residual_decomp.py
View File

@ -0,0 +1,15 @@
import torch.nn as nn
class ResidualDecomp(nn.Module):
"""Residual decomposition."""
def __init__(self, input_shape):
super().__init__()
self.ln = nn.LayerNorm(input_shape[-1])
self.ac = nn.ReLU()
def forward(self, x, y):
u = x - self.ac(y)
u = self.ln(u)
return u

model/D2STGNN/diffusion_block/dif_block.py
View File

@ -0,0 +1,56 @@
import torch.nn as nn
from model.D2STGNN.diffusion_block.forecast import Forecast
from model.D2STGNN.diffusion_block.dif_model import STLocalizedConv
from model.D2STGNN.decouple.residual_decomp import ResidualDecomp
class DifBlock(nn.Module):
def __init__(self, hidden_dim, forecast_hidden_dim=256, dy_graph=None, **model_args):
"""Diffusion block
Args:
hidden_dim (int): hidden dimension.
forecast_hidden_dim (int, optional): forecast branch hidden dimension. Defaults to 256.
dy_graph (bool, optional): if use dynamic graph. Defaults to None.
"""
super().__init__()
# diffusion model - only the dynamic graph is kept
self.localized_st_conv = STLocalizedConv(hidden_dim, dy_graph=dy_graph, **model_args)
# forecast
self.forecast_branch = Forecast(hidden_dim, forecast_hidden_dim=forecast_hidden_dim, **model_args)
# backcast
self.backcast_branch = nn.Linear(hidden_dim, hidden_dim)
# residual decomposition
self.residual_decompose = ResidualDecomp([-1, -1, -1, hidden_dim])
def forward(self, history_data, gated_history_data, dynamic_graph, static_graph=None):
"""Diffusion block, containing the diffusion model, forecast branch, backcast branch, and the residual decomposition link.
Args:
history_data (torch.Tensor): history data with shape [batch_size, seq_len, num_nodes, hidden_dim]
gated_history_data (torch.Tensor): gated history data with shape [batch_size, seq_len, num_nodes, hidden_dim]
dynamic_graph (list): dynamic graphs.
static_graph (list, optional): static graphs (unused).
Returns:
torch.Tensor: the output after the decoupling mechanism (backcast branch and the residual link), which should be fed to the inherent model.
Shape: [batch_size, seq_len', num_nodes, hidden_dim]. Kindly note that after the st conv, the sequence will be shorter.
torch.Tensor: the output of the forecast branch, which will be used to make final prediction.
Shape: [batch_size, seq_len'', num_nodes, forecast_hidden_dim]. seq_len'' = future_len / gap.
In order to reduce the error accumulation in the AR forecasting strategy, we let each hidden state generate the prediction of gap points, instead of a single point.
"""
# diffusion model - only the dynamic graph is used
hidden_states_dif = self.localized_st_conv(gated_history_data, dynamic_graph, static_graph)
# forecast branch: use the localized st conv to predict future hidden states.
forecast_hidden = self.forecast_branch(gated_history_data, hidden_states_dif, self.localized_st_conv, dynamic_graph, static_graph)
# backcast branch: use FC layer to do backcast
backcast_seq = self.backcast_branch(hidden_states_dif)
# residual decomposition: remove the learned knowledge from input data
history_data = history_data[:, -backcast_seq.shape[1]:, :, :]
backcast_seq_res = self.residual_decompose(history_data, backcast_seq)
return backcast_seq_res, forecast_hidden

model/D2STGNN/diffusion_block/dif_model.py
View File

@ -0,0 +1,128 @@
import torch
import torch.nn as nn
class STLocalizedConv(nn.Module):
def __init__(self, hidden_dim, dy_graph=None, **model_args):
super().__init__()
# gated temporal conv
self.k_s = model_args['k_s']
self.k_t = model_args['k_t']
self.hidden_dim = hidden_dim
# graph conv - only the dynamic graph is kept
self.use_dynamic_hidden_graph = dy_graph
# only the dynamic graph is considered
self.support_len = int(dy_graph) if dy_graph is not None else 0
# num_matric = 1 (X_0) + dynamic graphs count
self.num_matric = 1 + self.support_len
self.dropout = nn.Dropout(model_args['dropout'])
self.fc_list_updt = nn.Linear(
self.k_t * hidden_dim, self.k_t * hidden_dim, bias=False)
self.gcn_updt = nn.Linear(
self.hidden_dim*self.num_matric, self.hidden_dim)
# others
self.bn = nn.BatchNorm2d(self.hidden_dim)
self.activation = nn.ReLU()
def gconv(self, support, X_k, X_0):
out = [X_0]
batch_size, seq_len, _, hidden_dim = X_0.shape
for graph in support:
# make sure the shape of graph matches X_k
if len(graph.shape) == 3: # dynamic graph, shape [B, N, K*N]
# repeat graph along the seq_len dimension
graph = graph.unsqueeze(1).repeat(1, seq_len, 1, 1) # [B, L, N, K*N]
elif len(graph.shape) == 2: # static graph, shape [N, K*N]
graph = graph.unsqueeze(0).unsqueeze(1).repeat(batch_size, seq_len, 1, 1) # [B, L, N, K*N]
# make sure X_k has the expected shape
if X_k.dim() == 4: # [B, L, K*N, D]
# matrix multiplication: [B, L, N, K*N] x [B, L, K*N, D] -> [B, L, N, D]
H_k = torch.matmul(graph, X_k)
else:
H_k = torch.matmul(graph, X_k.unsqueeze(1))
H_k = H_k.squeeze(1)
out.append(H_k)
# concatenate all results
out = torch.cat(out, dim=-1)
# dynamically adjust the input dimension of the linear layer
if out.shape[-1] != self.gcn_updt.in_features:
# create a new linear layer that matches the current input dimension
new_gcn_updt = nn.Linear(out.shape[-1], self.hidden_dim).to(out.device)
# copy the existing parameters where possible
with torch.no_grad():
min_dim = min(out.shape[-1], self.gcn_updt.in_features)
new_gcn_updt.weight[:, :min_dim] = self.gcn_updt.weight[:, :min_dim]
if new_gcn_updt.bias is not None and self.gcn_updt.bias is not None:
new_gcn_updt.bias = self.gcn_updt.bias
self.gcn_updt = new_gcn_updt
out = self.gcn_updt(out)
out = self.dropout(out)
return out
def get_graph(self, support):
# Only used for static graphs (the static hidden graph and the predefined graph); not used for the dynamic graph.
if support is None or len(support) == 0:
return []
graph_ordered = []
mask = 1 - torch.eye(support[0].shape[0]).to(support[0].device)
for graph in support:
k_1_order = graph # 1 order
graph_ordered.append(k_1_order * mask)
# e.g., order = 3, k=[2, 3]; order = 2, k=[2]
for k in range(2, self.k_s+1):
k_1_order = torch.matmul(graph, k_1_order)
graph_ordered.append(k_1_order * mask)
# get st localed graph
st_local_graph = []
for graph in graph_ordered:
graph = graph.unsqueeze(-2).expand(-1, self.k_t, -1)
graph = graph.reshape(
graph.shape[0], graph.shape[1] * graph.shape[2])
# [num_nodes, kernel_size x num_nodes]
st_local_graph.append(graph)
# [order, num_nodes, kernel_size x num_nodes]
return st_local_graph
def forward(self, X, dynamic_graph, static_graph=None):
# X: [bs, seq, nodes, feat]
# [bs, seq, num_nodes, ks, num_feat]
X = X.unfold(1, self.k_t, 1).permute(0, 1, 2, 4, 3)
# seq_len is changing
batch_size, seq_len, num_nodes, kernel_size, num_feat = X.shape
# support - only the dynamic graph is kept
support = []
if self.use_dynamic_hidden_graph and dynamic_graph:
# the k-order graphs are calculated in the dynamic_graph_constructor component
support = support + dynamic_graph
# parallelize
X = X.reshape(batch_size, seq_len, num_nodes, kernel_size * num_feat)
# batch_size, seq_len, num_nodes, kernel_size * hidden_dim
out = self.fc_list_updt(X)
out = self.activation(out)
out = out.view(batch_size, seq_len, num_nodes, kernel_size, num_feat)
X_0 = torch.mean(out, dim=-2)
# batch_size, seq_len, kernel_size x num_nodes, hidden_dim
X_k = out.transpose(-3, -2).reshape(batch_size,
seq_len, kernel_size*num_nodes, num_feat)
# if support is empty, return X_0 directly
if len(support) == 0:
return X_0
# Nx3N 3NxD -> NxD: batch_size, seq_len, num_nodes, hidden_dim
hidden = self.gconv(support, X_k, X_0)
return hidden
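The comment "seq_len is changing" above refers to the temporal unfolding: each output step sees a window of k_t input steps, so the sequence shrinks by k_t - 1. A plain-PyTorch illustration with arbitrary sizes:

import torch

X = torch.randn(2, 24, 35, 32)             # [batch, seq, nodes, feat]
k_t = 3
win = X.unfold(1, k_t, 1)                  # [2, 22, 35, 32, 3]: 24 - k_t + 1 = 22 windows
win = win.permute(0, 1, 2, 4, 3)           # [2, 22, 35, 3, 32] = [bs, seq', nodes, k_t, feat]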

model/D2STGNN/diffusion_block/forecast.py
View File

@ -0,0 +1,27 @@
import torch
import torch.nn as nn
class Forecast(nn.Module):
def __init__(self, hidden_dim, forecast_hidden_dim=None, **model_args):
super().__init__()
self.k_t = model_args['k_t']
self.output_seq_len = model_args['horizon'] # use horizon as the target sequence length
self.forecast_fc = nn.Linear(hidden_dim, forecast_hidden_dim)
self.model_args = model_args
def forward(self, gated_history_data, hidden_states_dif, localized_st_conv, dynamic_graph, static_graph):
predict = []
history = gated_history_data
predict.append(hidden_states_dif[:, -1, :, :].unsqueeze(1))
for _ in range(int(self.output_seq_len / self.model_args['gap'])-1):
_1 = predict[-self.k_t:]
if len(_1) < self.k_t:
sub = self.k_t - len(_1)
_2 = history[:, -sub:, :, :]
_1 = torch.cat([_2] + _1, dim=1)
else:
_1 = torch.cat(_1, dim=1)
predict.append(localized_st_conv(_1, dynamic_graph, static_graph))
predict = torch.cat(predict, dim=1)
predict = self.forecast_fc(predict)
return predict

model/D2STGNN/dynamic_graph_conv/dy_graph_conv.py
View File

@ -0,0 +1,66 @@
import torch.nn as nn
from model.D2STGNN.dynamic_graph_conv.utils.distance import DistanceFunction
from model.D2STGNN.dynamic_graph_conv.utils.mask import Mask
from model.D2STGNN.dynamic_graph_conv.utils.normalizer import Normalizer, MultiOrder
class DynamicGraphConstructor(nn.Module):
def __init__(self, **model_args):
super().__init__()
# model args
self.k_s = model_args['k_s'] # spatial order
self.k_t = model_args['k_t'] # temporal kernel size
# hidden dimension
self.hidden_dim = model_args['num_hidden']
# trainable node embedding dimension
self.node_dim = model_args['node_hidden']
self.distance_function = DistanceFunction(**model_args)
self.mask = Mask(**model_args)
self.normalizer = Normalizer()
self.multi_order = MultiOrder(order=self.k_s)
def st_localization(self, graph_ordered):
st_local_graph = []
for modality_i in graph_ordered:
for k_order_graph in modality_i:
k_order_graph = k_order_graph.unsqueeze(
-2).expand(-1, -1, self.k_t, -1)
k_order_graph = k_order_graph.reshape(
k_order_graph.shape[0], k_order_graph.shape[1], k_order_graph.shape[2] * k_order_graph.shape[3])
st_local_graph.append(k_order_graph)
return st_local_graph
def forward(self, **inputs):
"""Dynamic graph learning module.
Args:
history_data (torch.Tensor): input data with shape (B, L, N, D)
node_u (torch.Parameter): node embedding E_u
node_d (torch.Parameter): node embedding E_d
time_in_day_feat (torch.Parameter): time embedding T_D
day_in_week_feat (torch.Parameter): time embedding T_W
Returns:
list: dynamic graphs
"""
X = inputs['history_data']
E_d = inputs['node_d'] # argument renamed to node_d
E_u = inputs['node_u'] # argument renamed to node_u
T_D = inputs['time_in_day_feat']
D_W = inputs['day_in_week_feat']
# distance calculation
dist_mx = self.distance_function(X, E_d, E_u, T_D, D_W)
# mask
dist_mx = self.mask(dist_mx)
# normalization
dist_mx = self.normalizer(dist_mx)
# multi order
mul_mx = self.multi_order(dist_mx)
# spatial temporal localization
dynamic_graphs = self.st_localization(mul_mx)
return dynamic_graphs
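A standalone shape sketch (dummy inputs, sizes from the AirQuality config, not part of the commit) of the constructor above: with k_s = 2 spatial orders and the two attention-style adjacencies produced by DistanceFunction, the result is a list of 4 localized graphs of shape [batch, num_nodes, k_t * num_nodes].

import torch

dgc = DynamicGraphConstructor(num_hidden=32, node_hidden=10, k_s=2, k_t=3,
                              seq_len=24, time_emb_dim=10, dropout=0.1)
graphs = dgc(history_data=torch.randn(2, 24, 35, 6),
             node_u=torch.randn(35, 10), node_d=torch.randn(35, 10),
             time_in_day_feat=torch.randn(2, 24, 35, 10),
             day_in_week_feat=torch.randn(2, 24, 35, 10))
print(len(graphs), graphs[0].shape)        # 4 torch.Size([2, 35, 105])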

model/D2STGNN/dynamic_graph_conv/utils/distance.py
View File

@ -0,0 +1,59 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistanceFunction(nn.Module):
def __init__(self, **model_args):
super().__init__()
# attributes
self.hidden_dim = model_args['num_hidden']
self.node_dim = model_args['node_hidden']
self.time_slot_emb_dim = self.hidden_dim
self.input_seq_len = model_args['seq_len']
# Time Series Feature Extraction
self.dropout = nn.Dropout(model_args['dropout'])
self.fc_ts_emb1 = nn.Linear(self.input_seq_len, self.hidden_dim * 2)
self.fc_ts_emb2 = nn.Linear(self.hidden_dim * 2, self.hidden_dim)
self.ts_feat_dim= self.hidden_dim
# Time Slot Embedding Extraction
self.time_slot_embedding = nn.Linear(model_args['time_emb_dim'], self.time_slot_emb_dim)
# Distance Score
self.all_feat_dim = self.ts_feat_dim + self.node_dim + model_args['time_emb_dim']*2
self.WQ = nn.Linear(self.all_feat_dim, self.hidden_dim, bias=False)
self.WK = nn.Linear(self.all_feat_dim, self.hidden_dim, bias=False)
self.bn = nn.BatchNorm1d(self.hidden_dim*2)
def reset_parameters(self):
# initialize the parameters of all linear layers
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, X, E_d, E_u, T_D, D_W):
# last pooling
T_D = T_D[:, -1, :, :]
D_W = D_W[:, -1, :, :]
# dynamic information
X = X[:, :, :, 0].transpose(1, 2).contiguous() # X->[batch_size, seq_len, num_nodes]->[batch_size, num_nodes, seq_len]
[batch_size, num_nodes, seq_len] = X.shape
X = X.view(batch_size * num_nodes, seq_len)
dy_feat = self.fc_ts_emb2(self.dropout(self.bn(F.relu(self.fc_ts_emb1(X))))) # [batchsize, num_nodes, hidden_dim]
dy_feat = dy_feat.view(batch_size, num_nodes, -1)
# node embedding
emb1 = E_d.unsqueeze(0).expand(batch_size, -1, -1)
emb2 = E_u.unsqueeze(0).expand(batch_size, -1, -1)
# distance calculation
X1 = torch.cat([dy_feat, T_D, D_W, emb1], dim=-1) # hidden state for calculating distance
X2 = torch.cat([dy_feat, T_D, D_W, emb2], dim=-1) # hidden state for calculating distance
X = [X1, X2]
adjacent_list = []
for _ in X:
Q = self.WQ(_)
K = self.WK(_)
QKT = torch.bmm(Q, K.transpose(-1, -2)) / math.sqrt(self.hidden_dim)
W = torch.softmax(QKT, dim=-1)
adjacent_list.append(W)
return adjacent_list

model/D2STGNN/dynamic_graph_conv/utils/mask.py
View File

@ -0,0 +1,21 @@
import torch
import torch.nn as nn
class Mask(nn.Module):
def __init__(self, **model_args):
super().__init__()
self.mask = model_args.get('adjs', None) # adjs is allowed to be None
def _mask(self, index, adj):
if self.mask is None or len(self.mask) == 0:
# if there is no predefined adjacency matrix, return adj unchanged
return adj
else:
mask = self.mask[index] + torch.ones_like(self.mask[index]) * 1e-7
return mask.to(adj.device) * adj
def forward(self, adj):
result = []
for index, _ in enumerate(adj):
result.append(self._mask(index, _))
return result

model/D2STGNN/dynamic_graph_conv/utils/normalizer.py
View File

@ -0,0 +1,42 @@
import torch
import torch.nn as nn
def remove_nan_inf(x):
"""移除张量中的nan和inf值"""
x = torch.where(torch.isnan(x) | torch.isinf(x), torch.zeros_like(x), x)
return x
class Normalizer(nn.Module):
def __init__(self):
super().__init__()
def _norm(self, graph):
degree = torch.sum(graph, dim=2)
degree = remove_nan_inf(1 / degree)
degree = torch.diag_embed(degree)
normed_graph = torch.bmm(degree, graph)
return normed_graph
def forward(self, adj):
return [self._norm(_) for _ in adj]
class MultiOrder(nn.Module):
def __init__(self, order=2):
super().__init__()
self.order = order
def _multi_order(self, graph):
graph_ordered = []
k_1_order = graph # 1 order
mask = torch.eye(graph.shape[1]).to(graph.device)
mask = 1 - mask
graph_ordered.append(k_1_order * mask)
for k in range(2, self.order+1): # e.g., order = 3, k=[2, 3]; order = 2, k=[2]
k_1_order = torch.matmul(k_1_order, graph)
graph_ordered.append(k_1_order * mask)
return graph_ordered
def forward(self, adj):
return [self._multi_order(_) for _ in adj]
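What Normalizer._norm computes, written out with plain tensors (arbitrary sizes): a row-wise D^-1 A normalization per graph in the batch, with remove_nan_inf guarding against empty rows.

import torch

A = torch.rand(2, 4, 4)                 # [batch, N, N]
deg = A.sum(dim=2)                      # row sums, [batch, N]
A_norm = torch.bmm(torch.diag_embed(1.0 / deg), A)
print(A_norm.sum(dim=2))                # each row now sums to 1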

model/D2STGNN/inherent_block/forecast.py
View File

@ -0,0 +1,31 @@
import torch
import torch.nn as nn
class Forecast(nn.Module):
def __init__(self, hidden_dim, fk_dim, **model_args):
super().__init__()
self.output_seq_len = model_args['seq_len']
self.model_args = model_args
self.forecast_fc = nn.Linear(hidden_dim, fk_dim)
def forward(self, X, RNN_H, Z, transformer_layer, rnn_layer, pe):
[batch_size, _, num_nodes, num_feat] = X.shape
predict = [Z[-1, :, :].unsqueeze(0)]
for _ in range(int(self.output_seq_len / self.model_args['gap'])-1):
# RNN
_gru = rnn_layer.gru_cell(predict[-1][0], RNN_H[-1]).unsqueeze(0)
RNN_H = torch.cat([RNN_H, _gru], dim=0)
# Positional Encoding
if pe is not None:
RNN_H = pe(RNN_H)
# Transformer
_Z = transformer_layer(_gru, K=RNN_H, V=RNN_H)
predict.append(_Z)
predict = torch.cat(predict, dim=0)
predict = predict.reshape(-1, batch_size, num_nodes, num_feat)
predict = predict.transpose(0, 1)
predict = self.forecast_fc(predict)
return predict

model/D2STGNN/inherent_block/inh_block.py
View File

@ -0,0 +1,86 @@
import math
import torch
import torch.nn as nn
from model.D2STGNN.decouple.residual_decomp import ResidualDecomp
from model.D2STGNN.inherent_block.inh_model import RNNLayer, TransformerLayer
from model.D2STGNN.inherent_block.forecast import Forecast
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=None, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, X):
X = X + self.pe[:X.size(0)]
X = self.dropout(X)
return X
class InhBlock(nn.Module):
def __init__(self, hidden_dim, num_heads=4, bias=True, forecast_hidden_dim=256, **model_args):
"""Inherent block
Args:
hidden_dim (int): hidden dimension
num_heads (int, optional): number of heads of MSA. Defaults to 4.
bias (bool, optional): if use bias. Defaults to True.
forecast_hidden_dim (int, optional): forecast branch hidden dimension. Defaults to 256.
"""
super().__init__()
self.num_feat = hidden_dim
self.hidden_dim = hidden_dim
# inherent model
self.pos_encoder = PositionalEncoding(hidden_dim, model_args['dropout'])
self.rnn_layer = RNNLayer(hidden_dim, model_args['dropout'])
self.transformer_layer = TransformerLayer(hidden_dim, num_heads, model_args['dropout'], bias)
# forecast branch
self.forecast_block = Forecast(hidden_dim, forecast_hidden_dim, **model_args)
# backcast branch
self.backcast_fc = nn.Linear(hidden_dim, hidden_dim)
# residual decomposition
self.residual_decompose = ResidualDecomp([-1, -1, -1, hidden_dim])
def forward(self, hidden_inherent_signal):
"""Inherent block, containing the inherent model, forecast branch, backcast branch, and the residual decomposition link.
Args:
hidden_inherent_signal (torch.Tensor): hidden inherent signal with shape [batch_size, seq_len, num_nodes, num_feat].
Returns:
torch.Tensor: the output after the decoupling mechanism (backcast branch and the residual link), which should be fed to the next decouple layer.
Shape: [batch_size, seq_len, num_nodes, hidden_dim].
torch.Tensor: the output of the forecast branch, which will be used to make final prediction.
Shape: [batch_size, seq_len'', num_nodes, forecast_hidden_dim]. seq_len'' = future_len / gap.
In order to reduce the error accumulation in the AR forecasting strategy, we let each hidden state generate the prediction of gap points, instead of a single point.
"""
[batch_size, seq_len, num_nodes, num_feat] = hidden_inherent_signal.shape
# inherent model
## rnn
hidden_states_rnn = self.rnn_layer(hidden_inherent_signal)
## pe
hidden_states_rnn = self.pos_encoder(hidden_states_rnn)
## MSA
hidden_states_inh = self.transformer_layer(hidden_states_rnn, hidden_states_rnn, hidden_states_rnn)
# forecast branch
forecast_hidden = self.forecast_block(hidden_inherent_signal, hidden_states_rnn, hidden_states_inh, self.transformer_layer, self.rnn_layer, self.pos_encoder)
# backcast branch
hidden_states_inh = hidden_states_inh.reshape(seq_len, batch_size, num_nodes, num_feat)
hidden_states_inh = hidden_states_inh.transpose(0, 1)
backcast_seq = self.backcast_fc(hidden_states_inh)
backcast_seq_res= self.residual_decompose(hidden_inherent_signal, backcast_seq)
return backcast_seq_res, forecast_hidden

model/D2STGNN/inherent_block/inh_model.py
View File

@ -0,0 +1,35 @@
import torch as th
import torch.nn as nn
from torch.nn import MultiheadAttention
class RNNLayer(nn.Module):
def __init__(self, hidden_dim, dropout=None):
super().__init__()
self.hidden_dim = hidden_dim
self.gru_cell = nn.GRUCell(hidden_dim, hidden_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, X):
[batch_size, seq_len, num_nodes, hidden_dim] = X.shape
X = X.transpose(1, 2).reshape(batch_size * num_nodes, seq_len, hidden_dim)
hx = th.zeros_like(X[:, 0, :])
output = []
for _ in range(X.shape[1]):
hx = self.gru_cell(X[:, _, :], hx)
output.append(hx)
output = th.stack(output, dim=0)
output = self.dropout(output)
return output
class TransformerLayer(nn.Module):
def __init__(self, hidden_dim, num_heads=4, dropout=None, bias=True):
super().__init__()
self.multi_head_self_attention = MultiheadAttention(hidden_dim, num_heads, dropout=dropout, bias=bias)
self.dropout = nn.Dropout(dropout)
def forward(self, X, K, V):
hidden_states_MSA = self.multi_head_self_attention(X, K, V)[0]
hidden_states_MSA = self.dropout(hidden_states_MSA)
return hidden_states_MSA

View File

@ -0,0 +1,7 @@
[
{
"name": "D2STGNN",
"module": "model.D2STGNN.D2STGNN",
"entry": "D2STGNN"
}
]
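A hedged sketch of how a registry entry like this could be resolved into a class via importlib; how init.init_model actually consumes the file is not shown in this commit, and the JSON path below is hypothetical.

import importlib, json

with open("model/D2STGNN/model.json") as f:          # hypothetical path for the registry above
    entry = json.load(f)[0]
module = importlib.import_module(entry["module"])    # model.D2STGNN.D2STGNN
ModelClass = getattr(module, entry["entry"])         # the D2STGNN class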

View File

@@ -6,14 +6,15 @@ import utils.initializer as init
from dataloader.loader_selector import get_dataloader
from trainer.trainer_selector import select_trainer
def read_config(config_path):
with open(config_path, "r") as file:
config = yaml.safe_load(file)
# global configuration
device = "cpu" # set the device to cuda:0
seed = 2023 # random seed
epochs = 1 # number of training epochs
# items to copy into the config
config["basic"]["device"] = device
@@ -23,6 +24,7 @@ def read_config(config_path):
config["train"]["epochs"] = epochs
return config
def run(config):
init.init_seed(config["basic"]["seed"])
model = init.init_model(config)
@@ -34,10 +36,15 @@ def run(config):
init.create_logs(config)
trainer = select_trainer(
model,
-loss, optimizer,
-train_loader, val_loader, test_loader, scaler,
+loss,
+optimizer,
+train_loader,
+val_loader,
+test_loader,
+scaler,
config,
-lr_scheduler, extra_data,
+lr_scheduler,
+extra_data,
)
# start training
@@ -54,17 +61,20 @@ def run(config):
)
trainer.test(
model.to(config["basic"]["device"]),
-trainer.args, test_loader, scaler,
+trainer.args,
+test_loader,
+scaler,
trainer.logger,
)
case _:
raise ValueError(f"Unsupported mode: {config['basic']['mode']}")
-def main(model, data, debug=False):
+def main(model_list, data, debug=False):
# my debug switch; set it to str(False) when not testing
# os.environ["TRY"] = str(False)
os.environ["TRY"] = str(debug)
for model in model_list:
for dataset in data:
config_path = f"./config/{model}/{dataset}.yaml"
@@ -77,22 +87,25 @@ def main(model, data, debug=False):
except Exception as e:
import traceback
import sys, traceback
tb_lines = traceback.format_exc().splitlines()
# print the full traceback only if it is not an AssertionError
if not tb_lines[-1].startswith("AssertionError"):
traceback.print_exc()
-print(f"\n===== {model} on {dataset} failed with error: {e} =====\n")
+print(
+f"\n===== {model} on {dataset} failed with error: {e} =====\n"
+)
else:
run(config)
if __name__ == "__main__":
# for debugging
# model_list = ["iTransformer", "PatchTST", "HI"]
-model_list = ["STNorm"]
+model_list = ["D2STGNN"]
# model_list = ["PatchTST"]
# dataset_list = ["AirQuality"]
-dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]
+# dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]
# dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
-main(model_list, dataset_list, debug = True)
+dataset_list = ["BJTaxi-OutFlow"]
+main(model_list, dataset_list, debug=True)