diff --git a/config/D2STGNN/AirQuality.yaml b/config/D2STGNN/AirQuality.yaml new file mode 100644 index 0000000..dea3149 --- /dev/null +++ b/config/D2STGNN/AirQuality.yaml @@ -0,0 +1,60 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: D2STGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + num_nodes: 35 + num_layers: 4 + num_hidden: 32 + forecast_dim: 256 + output_hidden: 512 + output_dim: 6 + seq_len: 24 + horizon: 24 + input_dim: 6 + num_timesteps_in_day: 24 + time_emb_dim: 10 + node_hidden: 10 + dy_graph: True + sta_graph: False + gap: 3 + k_s: 2 + k_t: 3 + dropout: 0.1 + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/D2STGNN/BJTaxi-InFlow.yaml b/config/D2STGNN/BJTaxi-InFlow.yaml new file mode 100644 index 0000000..0a542a7 --- /dev/null +++ b/config/D2STGNN/BJTaxi-InFlow.yaml @@ -0,0 +1,60 @@ +basic: + dataset: BJTaxi-InFlow + device: cuda:0 + mode: train + model: D2STGNN + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + num_nodes: 1024 + num_layers: 4 + num_hidden: 32 + forecast_dim: 256 + output_hidden: 512 + output_dim: 1 + seq_len: 24 + horizon: 24 + input_dim: 1 + num_timesteps_in_day: 48 + time_emb_dim: 10 + node_hidden: 10 + dy_graph: True + sta_graph: False + gap: 3 + k_s: 2 + k_t: 3 + dropout: 0.1 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/D2STGNN/BJTaxi-OutFlow.yaml b/config/D2STGNN/BJTaxi-OutFlow.yaml new file mode 100644 index 0000000..2fcb8b2 --- /dev/null +++ b/config/D2STGNN/BJTaxi-OutFlow.yaml @@ -0,0 +1,60 @@ +basic: + dataset: BJTaxi-OutFlow + device: cuda:0 + mode: train + model: D2STGNN + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + num_nodes: 1024 + num_layers: 4 + num_hidden: 32 + forecast_dim: 256 + output_hidden: 512 + output_dim: 1 + seq_len: 24 + horizon: 24 + input_dim: 1 + num_timesteps_in_day: 48 + time_emb_dim: 10 + node_hidden: 10 + dy_graph: True + sta_graph: False + gap: 3 + k_s: 2 + k_t: 3 + dropout: 0.1 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git 
a/config/D2STGNN/METR-LA.yaml b/config/D2STGNN/METR-LA.yaml new file mode 100644 index 0000000..c26b3ed --- /dev/null +++ b/config/D2STGNN/METR-LA.yaml @@ -0,0 +1,60 @@ +basic: + dataset: METR-LA + device: cuda:0 + mode: train + model: D2STGNN + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + num_nodes: 207 + num_layers: 4 + num_hidden: 32 + forecast_dim: 256 + output_hidden: 512 + output_dim: 1 + seq_len: 24 + horizon: 24 + input_dim: 1 + num_timesteps_in_day: 288 + time_emb_dim: 10 + node_hidden: 10 + dy_graph: True + sta_graph: False + gap: 3 + k_s: 2 + k_t: 3 + dropout: 0.1 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/D2STGNN/NYCBike-InFlow.yaml b/config/D2STGNN/NYCBike-InFlow.yaml new file mode 100644 index 0000000..f462340 --- /dev/null +++ b/config/D2STGNN/NYCBike-InFlow.yaml @@ -0,0 +1,60 @@ +basic: + dataset: NYCBike-InFlow + device: cuda:0 + mode: train + model: D2STGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + num_nodes: 128 + num_layers: 4 + num_hidden: 32 + forecast_dim: 256 + output_hidden: 512 + output_dim: 1 + seq_len: 24 + horizon: 24 + input_dim: 1 + num_timesteps_in_day: 48 + time_emb_dim: 10 + node_hidden: 10 + dy_graph: True + sta_graph: False + gap: 3 + k_s: 2 + k_t: 3 + dropout: 0.1 + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/D2STGNN/NYCBike-OutFlow.yaml b/config/D2STGNN/NYCBike-OutFlow.yaml new file mode 100644 index 0000000..1c067d0 --- /dev/null +++ b/config/D2STGNN/NYCBike-OutFlow.yaml @@ -0,0 +1,60 @@ +basic: + dataset: NYCBike-OutFlow + device: cuda:0 + mode: train + model: D2STGNN + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + num_nodes: 128 + num_layers: 4 + num_hidden: 32 + forecast_dim: 256 + output_hidden: 512 + output_dim: 1 + seq_len: 24 + horizon: 24 + input_dim: 1 + num_timesteps_in_day: 48 + time_emb_dim: 10 + node_hidden: 10 + dy_graph: True + sta_graph: False + gap: 3 + k_s: 2 + k_t: 3 + dropout: 0.1 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git 
a/config/D2STGNN/PEMS-BAY.yaml b/config/D2STGNN/PEMS-BAY.yaml new file mode 100644 index 0000000..0e6bd94 --- /dev/null +++ b/config/D2STGNN/PEMS-BAY.yaml @@ -0,0 +1,60 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: D2STGNN + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + num_nodes: 325 + num_layers: 4 + num_hidden: 32 + forecast_dim: 256 + output_hidden: 512 + output_dim: 1 + seq_len: 24 + horizon: 24 + input_dim: 1 + num_timesteps_in_day: 288 + time_emb_dim: 10 + node_hidden: 10 + dy_graph: True + sta_graph: False + gap: 3 + k_s: 2 + k_t: 3 + dropout: 0.1 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/D2STGNN/SolarEnergy.yaml b/config/D2STGNN/SolarEnergy.yaml new file mode 100644 index 0000000..4c3fa9f --- /dev/null +++ b/config/D2STGNN/SolarEnergy.yaml @@ -0,0 +1,60 @@ +basic: + dataset: SolarEnergy + device: cuda:0 + mode: train + model: D2STGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + num_nodes: 137 + num_layers: 4 + num_hidden: 32 + forecast_dim: 256 + output_hidden: 512 + output_dim: 1 + seq_len: 24 + horizon: 24 + input_dim: 1 + num_timesteps_in_day: 24 + time_emb_dim: 10 + node_hidden: 10 + dy_graph: True + sta_graph: False + gap: 3 + k_s: 2 + k_t: 3 + dropout: 0.1 + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/model/D2STGNN/D2STGNN.py b/model/D2STGNN/D2STGNN.py new file mode 100644 index 0000000..135380a --- /dev/null +++ b/model/D2STGNN/D2STGNN.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from model.D2STGNN.diffusion_block.dif_block import DifBlock +from model.D2STGNN.inherent_block.inh_block import InhBlock +from model.D2STGNN.dynamic_graph_conv.dy_graph_conv import DynamicGraphConstructor +from model.D2STGNN.decouple.estimation_gate import EstimationGate + +class DecoupleLayer(nn.Module): + def __init__(self, hidden_dim, fk_dim, args): + super().__init__() + self.est_gate = EstimationGate(node_emb_dim=args['node_hidden'], time_emb_dim=args['time_emb_dim'], hidden_dim=64) + # only pass the required arguments; dy_graph is forwarded through **args + self.dif_layer = DifBlock(hidden_dim, forecast_hidden_dim=fk_dim, **args) + self.inh_layer = InhBlock(hidden_dim, forecast_hidden_dim=fk_dim, **args) + + def forward(self, x, dyn_graph, sta_graph=None, node_u=None, node_d=None, t_in_day=None, t_in_week=None): + gated_x = self.est_gate(node_u, node_d, t_in_day, t_in_week, x) + dif_back, dif_hidden = self.dif_layer(x, gated_x, dyn_graph, sta_graph) + inh_back, inh_hidden = self.inh_layer(dif_back) + return
inh_back, dif_hidden, inh_hidden + +class D2STGNN(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args # keep args for use in forward + self.num_nodes = args['num_nodes'] + self.num_layers = args['num_layers'] + self.hidden_dim = args['num_hidden'] + self.forecast_dim = args['forecast_dim'] + self.output_hidden = args['output_hidden'] + self.output_dim = args['output_dim'] + self.in_feat = args['input_dim'] + + self.embedding = nn.Linear(self.in_feat, self.hidden_dim) + self.T_i_D_emb = nn.Parameter(torch.empty(args.get('num_timesteps_in_day',288), args['time_emb_dim'])) + self.D_i_W_emb = nn.Parameter(torch.empty(7, args['time_emb_dim'])) + self.node_u = nn.Parameter(torch.empty(self.num_nodes, args['node_hidden'])) + self.node_d = nn.Parameter(torch.empty(self.num_nodes, args['node_hidden'])) + + self.layers = nn.ModuleList([DecoupleLayer(self.hidden_dim, self.forecast_dim, args) for _ in range(self.num_layers)]) + if args.get('dy_graph', False): + self.dynamic_graph_constructor = DynamicGraphConstructor(**args) + + self.out_fc1 = nn.Linear(self.forecast_dim, self.output_hidden) + self.out_fc2 = nn.Linear(self.output_hidden, args['gap'] * args['output_dim']) + self._reset_parameters() + + def _reset_parameters(self): + for p in [self.node_u, self.node_d, self.T_i_D_emb, self.D_i_W_emb]: + nn.init.xavier_uniform_(p) + + def _prepare_inputs(self, x): + node_u, node_d = self.node_u, self.node_d + t_in_day = self.T_i_D_emb[(x[:, :, :, -2]*self.T_i_D_emb.size(0)).long()] + t_in_week = self.D_i_W_emb[x[:, :, :, -1].long()] + return x[:, :, :, :-2], node_u, node_d, t_in_day, t_in_week + + def _graph_constructor(self, node_u, node_d, x, t_in_day, t_in_week): + # build dynamic graphs only; the static graph is removed + dyn_graph = self.dynamic_graph_constructor(node_u=node_u, node_d=node_d, history_data=x, time_in_day_feat=t_in_day, day_in_week_feat=t_in_week) if hasattr(self, 'dynamic_graph_constructor') else [] + return [], dyn_graph + + def forward(self, x): + x, node_u, node_d, t_in_day, t_in_week = self._prepare_inputs(x) + sta_graph, dyn_graph = self._graph_constructor(node_u, node_d, x, t_in_day, t_in_week) + x = self.embedding(x) + + dif_hidden_list, inh_hidden_list = [], [] + backcast = x + for layer in self.layers: + backcast, dif_hidden, inh_hidden = layer(backcast, dyn_graph, sta_graph, node_u, node_d, t_in_day, t_in_week) + dif_hidden_list.append(dif_hidden) + inh_hidden_list.append(inh_hidden) + + forecast_hidden = sum(dif_hidden_list) + sum(inh_hidden_list) + # reshape the output so it matches the labels + forecast = self.out_fc1(F.relu(forecast_hidden)) + forecast = F.relu(forecast) + forecast = self.out_fc2(forecast) + # make sure the feature dimension is correct + if forecast.size(-1) != self.args['output_dim']: + forecast = forecast[..., :self.args['output_dim']] + # make sure the number of time steps matches the horizon + if forecast.size(1) != self.args['horizon']: + # if there are too few time steps, repeat them to fill the horizon + forecast = forecast.repeat(1, self.args['horizon'] // forecast.size(1) + 1, 1, 1)[:, :self.args['horizon'], :, :] + return forecast diff --git a/model/D2STGNN/decouple/estimation_gate.py b/model/D2STGNN/decouple/estimation_gate.py new file mode 100644 index 0000000..874b185 --- /dev/null +++ b/model/D2STGNN/decouple/estimation_gate.py @@ -0,0 +1,24 @@ +import torch +import torch.nn as nn + + +class EstimationGate(nn.Module): + """The estimation gate module.""" + + def __init__(self, node_emb_dim, time_emb_dim, hidden_dim): + super().__init__() + self.fully_connected_layer_1 = nn.Linear(2 * node_emb_dim + time_emb_dim * 2, hidden_dim) + self.activation = nn.ReLU() + self.fully_connected_layer_2 = nn.Linear(hidden_dim, 1) + +
def forward(self, node_embedding_u, node_embedding_d, time_in_day_feat, day_in_week_feat, history_data): + """Generate a gate value in (0, 1) from the current node and time-step embeddings to roughly estimate the proportion of the two hidden time series.""" + + batch_size, seq_length, _, _ = time_in_day_feat.shape + estimation_gate_feat = torch.cat([time_in_day_feat, day_in_week_feat, node_embedding_u.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_length, -1, -1), node_embedding_d.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_length, -1, -1)], dim=-1) + hidden = self.fully_connected_layer_1(estimation_gate_feat) + hidden = self.activation(hidden) + # activation + estimation_gate = torch.sigmoid(self.fully_connected_layer_2(hidden))[:, -history_data.shape[1]:, :, :] + history_data = history_data * estimation_gate + return history_data diff --git a/model/D2STGNN/decouple/residual_decomp.py b/model/D2STGNN/decouple/residual_decomp.py new file mode 100644 index 0000000..a5c50b7 --- /dev/null +++ b/model/D2STGNN/decouple/residual_decomp.py @@ -0,0 +1,15 @@ +import torch.nn as nn + + +class ResidualDecomp(nn.Module): + """Residual decomposition.""" + + def __init__(self, input_shape): + super().__init__() + self.ln = nn.LayerNorm(input_shape[-1]) + self.ac = nn.ReLU() + + def forward(self, x, y): + u = x - self.ac(y) + u = self.ln(u) + return u diff --git a/model/D2STGNN/diffusion_block/dif_block.py b/model/D2STGNN/diffusion_block/dif_block.py new file mode 100644 index 0000000..2c4f8b1 --- /dev/null +++ b/model/D2STGNN/diffusion_block/dif_block.py @@ -0,0 +1,56 @@ +import torch.nn as nn + +from model.D2STGNN.diffusion_block.forecast import Forecast +from model.D2STGNN.diffusion_block.dif_model import STLocalizedConv +from model.D2STGNN.decouple.residual_decomp import ResidualDecomp + + +class DifBlock(nn.Module): + def __init__(self, hidden_dim, forecast_hidden_dim=256, dy_graph=None, **model_args): + """Diffusion block + + Args: + hidden_dim (int): hidden dimension. + forecast_hidden_dim (int, optional): forecast branch hidden dimension. Defaults to 256. + dy_graph (bool, optional): whether to use the dynamic graph. Defaults to None. + """ + + super().__init__() + # diffusion model - only the dynamic graph is kept + self.localized_st_conv = STLocalizedConv(hidden_dim, dy_graph=dy_graph, **model_args) + + # forecast + self.forecast_branch = Forecast(hidden_dim, forecast_hidden_dim=forecast_hidden_dim, **model_args) + # backcast + self.backcast_branch = nn.Linear(hidden_dim, hidden_dim) + # residual decomposition + self.residual_decompose = ResidualDecomp([-1, -1, -1, hidden_dim]) + + def forward(self, history_data, gated_history_data, dynamic_graph, static_graph=None): + """Diffusion block, containing the diffusion model, forecast branch, backcast branch, and the residual decomposition link. + + Args: + history_data (torch.Tensor): history data with shape [batch_size, seq_len, num_nodes, hidden_dim] + gated_history_data (torch.Tensor): gated history data with shape [batch_size, seq_len, num_nodes, hidden_dim] + dynamic_graph (list): dynamic graphs. + static_graph (list, optional): static graphs (unused). + + Returns: + torch.Tensor: the output after the decoupling mechanism (backcast branch and the residual link), which should be fed to the inherent model. + Shape: [batch_size, seq_len', num_nodes, hidden_dim]. Kindly note that after the st conv, the sequence will be shorter. + torch.Tensor: the output of the forecast branch, which will be used to make the final prediction.
+ Shape: [batch_size, seq_len'', num_nodes, forecast_hidden_dim]. seq_len'' = future_len / gap. + In order to reduce the error accumulation in the AR forecasting strategy, we let each hidden state generate the prediction of gap points, instead of a single point. + """ + + # diffusion model - only dynamic graphs are used + hidden_states_dif = self.localized_st_conv(gated_history_data, dynamic_graph, static_graph) + # forecast branch: use the localized st conv to predict future hidden states. + forecast_hidden = self.forecast_branch(gated_history_data, hidden_states_dif, self.localized_st_conv, dynamic_graph, static_graph) + # backcast branch: use FC layer to do backcast + backcast_seq = self.backcast_branch(hidden_states_dif) + # residual decomposition: remove the learned knowledge from input data + history_data = history_data[:, -backcast_seq.shape[1]:, :, :] + backcast_seq_res = self.residual_decompose(history_data, backcast_seq) + + return backcast_seq_res, forecast_hidden diff --git a/model/D2STGNN/diffusion_block/dif_model.py b/model/D2STGNN/diffusion_block/dif_model.py new file mode 100644 index 0000000..e334653 --- /dev/null +++ b/model/D2STGNN/diffusion_block/dif_model.py @@ -0,0 +1,128 @@ +import torch +import torch.nn as nn + + +class STLocalizedConv(nn.Module): + def __init__(self, hidden_dim, dy_graph=None, **model_args): + super().__init__() + # gated temporal conv + self.k_s = model_args['k_s'] + self.k_t = model_args['k_t'] + self.hidden_dim = hidden_dim + + # graph conv - only the dynamic graph is kept + self.use_dynamic_hidden_graph = dy_graph + + # only dynamic graphs are considered + self.support_len = int(dy_graph) if dy_graph is not None else 0 + + # num_matric = 1 (X_0) + dynamic graphs count + self.num_matric = 1 + self.support_len + self.dropout = nn.Dropout(model_args['dropout']) + + self.fc_list_updt = nn.Linear( + self.k_t * hidden_dim, self.k_t * hidden_dim, bias=False) + self.gcn_updt = nn.Linear( + self.hidden_dim*self.num_matric, self.hidden_dim) + + # others + self.bn = nn.BatchNorm2d(self.hidden_dim) + self.activation = nn.ReLU() + + def gconv(self, support, X_k, X_0): + out = [X_0] + batch_size, seq_len, _, hidden_dim = X_0.shape + + for graph in support: + # make sure the graph shape matches X_k + if len(graph.shape) == 3: # dynamic graph with shape [B, N, K*N] + # repeat the graph along the seq_len dimension + graph = graph.unsqueeze(1).repeat(1, seq_len, 1, 1) # [B, L, N, K*N] + elif len(graph.shape) == 2: # static graph with shape [N, K*N] + graph = graph.unsqueeze(0).unsqueeze(1).repeat(batch_size, seq_len, 1, 1) # [B, L, N, K*N] + + # make sure X_k has the expected shape + if X_k.dim() == 4: # [B, L, K*N, D] + # matrix multiplication: [B, L, N, K*N] x [B, L, K*N, D] -> [B, L, N, D] + H_k = torch.matmul(graph, X_k) + else: + H_k = torch.matmul(graph, X_k.unsqueeze(1)) + H_k = H_k.squeeze(1) + + out.append(H_k) + + # concatenate all results + out = torch.cat(out, dim=-1) + + # dynamically adjust the input dimension of the linear layer + if out.shape[-1] != self.gcn_updt.in_features: + # create a new linear layer that matches the current input dimension + new_gcn_updt = nn.Linear(out.shape[-1], self.hidden_dim).to(out.device) + # copy over the old parameters where possible + with torch.no_grad(): + min_dim = min(out.shape[-1], self.gcn_updt.in_features) + new_gcn_updt.weight[:, :min_dim] = self.gcn_updt.weight[:, :min_dim] + if new_gcn_updt.bias is not None and self.gcn_updt.bias is not None: + new_gcn_updt.bias = self.gcn_updt.bias + self.gcn_updt = new_gcn_updt + + out = self.gcn_updt(out) + out = self.dropout(out) + return out + + def get_graph(self, support): + # Only used for static graphs (static hidden graph and predefined graph); not used for dynamic graphs.
+ if support is None or len(support) == 0: + return [] + + graph_ordered = [] + mask = 1 - torch.eye(support[0].shape[0]).to(support[0].device) + for graph in support: + k_1_order = graph # 1st order + graph_ordered.append(k_1_order * mask) + # e.g., order = 3, k=[2, 3]; order = 2, k=[2] + for k in range(2, self.k_s+1): + k_1_order = torch.matmul(graph, k_1_order) + graph_ordered.append(k_1_order * mask) + # get the st-localized graph + st_local_graph = [] + for graph in graph_ordered: + graph = graph.unsqueeze(-2).expand(-1, self.k_t, -1) + graph = graph.reshape( + graph.shape[0], graph.shape[1] * graph.shape[2]) + # [num_nodes, kernel_size x num_nodes] + st_local_graph.append(graph) + # [order, num_nodes, kernel_size x num_nodes] + return st_local_graph + + def forward(self, X, dynamic_graph, static_graph=None): + # X: [bs, seq, nodes, feat] + # [bs, seq, num_nodes, ks, num_feat] + X = X.unfold(1, self.k_t, 1).permute(0, 1, 2, 4, 3) + # seq_len is changing + batch_size, seq_len, num_nodes, kernel_size, num_feat = X.shape + + # support - only the dynamic graph is kept + support = [] + if self.use_dynamic_hidden_graph and dynamic_graph: + # the k-order graphs are already calculated in the dynamic_graph_constructor component + support = support + dynamic_graph + + # parallelize + X = X.reshape(batch_size, seq_len, num_nodes, kernel_size * num_feat) + # batch_size, seq_len, num_nodes, kernel_size * hidden_dim + out = self.fc_list_updt(X) + out = self.activation(out) + out = out.view(batch_size, seq_len, num_nodes, kernel_size, num_feat) + X_0 = torch.mean(out, dim=-2) + # batch_size, seq_len, kernel_size x num_nodes, hidden_dim + X_k = out.transpose(-3, -2).reshape(batch_size, + seq_len, kernel_size*num_nodes, num_feat) + + # if support is empty, return X_0 directly + if len(support) == 0: + return X_0 + + # Nx3N 3NxD -> NxD: batch_size, seq_len, num_nodes, hidden_dim + hidden = self.gconv(support, X_k, X_0) + return hidden diff --git a/model/D2STGNN/diffusion_block/forecast.py b/model/D2STGNN/diffusion_block/forecast.py new file mode 100644 index 0000000..f9798cc --- /dev/null +++ b/model/D2STGNN/diffusion_block/forecast.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn + +class Forecast(nn.Module): + def __init__(self, hidden_dim, forecast_hidden_dim=None, **model_args): + super().__init__() + self.k_t = model_args['k_t'] + self.output_seq_len = model_args['horizon'] # use horizon as the target sequence length + self.forecast_fc = nn.Linear(hidden_dim, forecast_hidden_dim) + self.model_args = model_args + + def forward(self, gated_history_data, hidden_states_dif, localized_st_conv, dynamic_graph, static_graph): + predict = [] + history = gated_history_data + predict.append(hidden_states_dif[:, -1, :, :].unsqueeze(1)) + for _ in range(int(self.output_seq_len / self.model_args['gap'])-1): + _1 = predict[-self.k_t:] + if len(_1) < self.k_t: + sub = self.k_t - len(_1) + _2 = history[:, -sub:, :, :] + _1 = torch.cat([_2] + _1, dim=1) + else: + _1 = torch.cat(_1, dim=1) + predict.append(localized_st_conv(_1, dynamic_graph, static_graph)) + predict = torch.cat(predict, dim=1) + predict = self.forecast_fc(predict) + return predict diff --git a/model/D2STGNN/dynamic_graph_conv/dy_graph_conv.py b/model/D2STGNN/dynamic_graph_conv/dy_graph_conv.py new file mode 100644 index 0000000..5f2d1de --- /dev/null +++ b/model/D2STGNN/dynamic_graph_conv/dy_graph_conv.py @@ -0,0 +1,66 @@ +import torch.nn as nn + +from model.D2STGNN.dynamic_graph_conv.utils.distance import DistanceFunction +from model.D2STGNN.dynamic_graph_conv.utils.mask import Mask + +from model.D2STGNN.dynamic_graph_conv.utils.normalizer
import Normalizer, MultiOrder + + +class DynamicGraphConstructor(nn.Module): + def __init__(self, **model_args): + super().__init__() + # model args + self.k_s = model_args['k_s'] # spatial order + self.k_t = model_args['k_t'] # temporal kernel size + # hidden dimension of the model + self.hidden_dim = model_args['num_hidden'] + # trainable node embedding dimension + self.node_dim = model_args['node_hidden'] + + self.distance_function = DistanceFunction(**model_args) + self.mask = Mask(**model_args) + self.normalizer = Normalizer() + self.multi_order = MultiOrder(order=self.k_s) + + def st_localization(self, graph_ordered): + st_local_graph = [] + for modality_i in graph_ordered: + for k_order_graph in modality_i: + k_order_graph = k_order_graph.unsqueeze( + -2).expand(-1, -1, self.k_t, -1) + k_order_graph = k_order_graph.reshape( + k_order_graph.shape[0], k_order_graph.shape[1], k_order_graph.shape[2] * k_order_graph.shape[3]) + st_local_graph.append(k_order_graph) + return st_local_graph + + def forward(self, **inputs): + """Dynamic graph learning module. + + Args: + history_data (torch.Tensor): input data with shape (B, L, N, D) + node_embedding_u (torch.Parameter): node embedding E_u + node_embedding_d (torch.Parameter): node embedding E_d + time_in_day_feat (torch.Parameter): time embedding T_D + day_in_week_feat (torch.Parameter): time embedding T_W + + Returns: + list: dynamic graphs + """ + + X = inputs['history_data'] + E_d = inputs['node_d'] # keyword renamed to node_d + E_u = inputs['node_u'] # keyword renamed to node_u + T_D = inputs['time_in_day_feat'] + D_W = inputs['day_in_week_feat'] + # distance calculation + dist_mx = self.distance_function(X, E_d, E_u, T_D, D_W) + # mask + dist_mx = self.mask(dist_mx) + # normalization + dist_mx = self.normalizer(dist_mx) + # multi order + mul_mx = self.multi_order(dist_mx) + # spatial temporal localization + dynamic_graphs = self.st_localization(mul_mx) + + return dynamic_graphs diff --git a/model/D2STGNN/dynamic_graph_conv/utils/distance.py b/model/D2STGNN/dynamic_graph_conv/utils/distance.py new file mode 100644 index 0000000..7a44a7e --- /dev/null +++ b/model/D2STGNN/dynamic_graph_conv/utils/distance.py @@ -0,0 +1,59 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +class DistanceFunction(nn.Module): + def __init__(self, **model_args): + super().__init__() + # attributes + self.hidden_dim = model_args['num_hidden'] + self.node_dim = model_args['node_hidden'] + self.time_slot_emb_dim = self.hidden_dim + self.input_seq_len = model_args['seq_len'] + # Time Series Feature Extraction + self.dropout = nn.Dropout(model_args['dropout']) + self.fc_ts_emb1 = nn.Linear(self.input_seq_len, self.hidden_dim * 2) + self.fc_ts_emb2 = nn.Linear(self.hidden_dim * 2, self.hidden_dim) + self.ts_feat_dim = self.hidden_dim + # Time Slot Embedding Extraction + self.time_slot_embedding = nn.Linear(model_args['time_emb_dim'], self.time_slot_emb_dim) + # Distance Score + self.all_feat_dim = self.ts_feat_dim + self.node_dim + model_args['time_emb_dim']*2 + self.WQ = nn.Linear(self.all_feat_dim, self.hidden_dim, bias=False) + self.WK = nn.Linear(self.all_feat_dim, self.hidden_dim, bias=False) + self.bn = nn.BatchNorm1d(self.hidden_dim*2) + + def reset_parameters(self): + # initialize the parameters of all linear layers + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, X, E_d, E_u, T_D, D_W): + # last pooling + T_D = T_D[:, -1, :, :] + D_W = D_W[:, -1, :, :] + # dynamic information +
X = X[:, :, :, 0].transpose(1, 2).contiguous() # X->[batch_size, seq_len, num_nodes]->[batch_size, num_nodes, seq_len] + [batch_size, num_nodes, seq_len] = X.shape + X = X.view(batch_size * num_nodes, seq_len) + dy_feat = self.fc_ts_emb2(self.dropout(self.bn(F.relu(self.fc_ts_emb1(X))))) # [batch_size * num_nodes, hidden_dim] + dy_feat = dy_feat.view(batch_size, num_nodes, -1) + # node embedding + emb1 = E_d.unsqueeze(0).expand(batch_size, -1, -1) + emb2 = E_u.unsqueeze(0).expand(batch_size, -1, -1) + # distance calculation + X1 = torch.cat([dy_feat, T_D, D_W, emb1], dim=-1) # hidden state for calculating distance + X2 = torch.cat([dy_feat, T_D, D_W, emb2], dim=-1) # hidden state for calculating distance + X = [X1, X2] + adjacent_list = [] + for _ in X: + Q = self.WQ(_) + K = self.WK(_) + QKT = torch.bmm(Q, K.transpose(-1, -2)) / math.sqrt(self.hidden_dim) + W = torch.softmax(QKT, dim=-1) + adjacent_list.append(W) + return adjacent_list diff --git a/model/D2STGNN/dynamic_graph_conv/utils/mask.py b/model/D2STGNN/dynamic_graph_conv/utils/mask.py new file mode 100644 index 0000000..e9e660e --- /dev/null +++ b/model/D2STGNN/dynamic_graph_conv/utils/mask.py @@ -0,0 +1,21 @@ +import torch +import torch.nn as nn + +class Mask(nn.Module): + def __init__(self, **model_args): + super().__init__() + self.mask = model_args.get('adjs', None) # adjs is allowed to be None + + def _mask(self, index, adj): + if self.mask is None or len(self.mask) == 0: + # if no predefined adjacency matrices are given, return adj unchanged + return adj + else: + mask = self.mask[index] + torch.ones_like(self.mask[index]) * 1e-7 + return mask.to(adj.device) * adj + + def forward(self, adj): + result = [] + for index, _ in enumerate(adj): + result.append(self._mask(index, _)) + return result diff --git a/model/D2STGNN/dynamic_graph_conv/utils/normalizer.py b/model/D2STGNN/dynamic_graph_conv/utils/normalizer.py new file mode 100644 index 0000000..0b832c1 --- /dev/null +++ b/model/D2STGNN/dynamic_graph_conv/utils/normalizer.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn + + +def remove_nan_inf(x): + """Replace NaN and Inf entries of a tensor with zeros.""" + x = torch.where(torch.isnan(x) | torch.isinf(x), torch.zeros_like(x), x) + return x + + +class Normalizer(nn.Module): + def __init__(self): + super().__init__() + + def _norm(self, graph): + degree = torch.sum(graph, dim=2) + degree = remove_nan_inf(1 / degree) + degree = torch.diag_embed(degree) + normed_graph = torch.bmm(degree, graph) + return normed_graph + + def forward(self, adj): + return [self._norm(_) for _ in adj] + +class MultiOrder(nn.Module): + def __init__(self, order=2): + super().__init__() + self.order = order + + def _multi_order(self, graph): + graph_ordered = [] + k_1_order = graph # 1st order + mask = torch.eye(graph.shape[1]).to(graph.device) + mask = 1 - mask + graph_ordered.append(k_1_order * mask) + for k in range(2, self.order+1): # e.g., order = 3, k=[2, 3]; order = 2, k=[2] + k_1_order = torch.matmul(k_1_order, graph) + graph_ordered.append(k_1_order * mask) + return graph_ordered + + def forward(self, adj): + return [self._multi_order(_) for _ in adj] diff --git a/model/D2STGNN/inherent_block/forecast.py b/model/D2STGNN/inherent_block/forecast.py new file mode 100644 index 0000000..04ec5ad --- /dev/null +++ b/model/D2STGNN/inherent_block/forecast.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn + +class Forecast(nn.Module): + def __init__(self, hidden_dim, fk_dim, **model_args): + super().__init__() + self.output_seq_len = model_args['seq_len'] + self.model_args = model_args + + self.forecast_fc = nn.Linear(hidden_dim,
fk_dim) + + def forward(self, X, RNN_H, Z, transformer_layer, rnn_layer, pe): + [batch_size, _, num_nodes, num_feat] = X.shape + + predict = [Z[-1, :, :].unsqueeze(0)] + for _ in range(int(self.output_seq_len / self.model_args['gap'])-1): + # RNN + _gru = rnn_layer.gru_cell(predict[-1][0], RNN_H[-1]).unsqueeze(0) + RNN_H = torch.cat([RNN_H, _gru], dim=0) + # Positional Encoding + if pe is not None: + RNN_H = pe(RNN_H) + # Transformer + _Z = transformer_layer(_gru, K=RNN_H, V=RNN_H) + predict.append(_Z) + + predict = torch.cat(predict, dim=0) + predict = predict.reshape(-1, batch_size, num_nodes, num_feat) + predict = predict.transpose(0, 1) + predict = self.forecast_fc(predict) + return predict diff --git a/model/D2STGNN/inherent_block/inh_block.py b/model/D2STGNN/inherent_block/inh_block.py new file mode 100644 index 0000000..6d39cd2 --- /dev/null +++ b/model/D2STGNN/inherent_block/inh_block.py @@ -0,0 +1,86 @@ +import math + +import torch +import torch.nn as nn + +from model.D2STGNN.decouple.residual_decomp import ResidualDecomp +from model.D2STGNN.inherent_block.inh_model import RNNLayer, TransformerLayer +from model.D2STGNN.inherent_block.forecast import Forecast + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=None, max_len: int = 5000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(max_len, 1, d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, X): + X = X + self.pe[:X.size(0)] + X = self.dropout(X) + return X + + +class InhBlock(nn.Module): + def __init__(self, hidden_dim, num_heads=4, bias=True, forecast_hidden_dim=256, **model_args): + """Inherent block + + Args: + hidden_dim (int): hidden dimension + num_heads (int, optional): number of heads of MSA. Defaults to 4. + bias (bool, optional): if use bias. Defaults to True. + forecast_hidden_dim (int, optional): forecast branch hidden dimension. Defaults to 256. + """ + super().__init__() + self.num_feat = hidden_dim + self.hidden_dim = hidden_dim + + # inherent model + self.pos_encoder = PositionalEncoding(hidden_dim, model_args['dropout']) + self.rnn_layer = RNNLayer(hidden_dim, model_args['dropout']) + self.transformer_layer = TransformerLayer(hidden_dim, num_heads, model_args['dropout'], bias) + + # forecast branch + self.forecast_block = Forecast(hidden_dim, forecast_hidden_dim, **model_args) + # backcast branch + self.backcast_fc = nn.Linear(hidden_dim, hidden_dim) + # residual decomposition + self.residual_decompose = ResidualDecomp([-1, -1, -1, hidden_dim]) + + def forward(self, hidden_inherent_signal): + """Inherent block, containing the inherent model, forecast branch, backcast branch, and the residual decomposition link. + + Args: + hidden_inherent_signal (torch.Tensor): hidden inherent signal with shape [batch_size, seq_len, num_nodes, num_feat]. + + Returns: + torch.Tensor: the output after the decoupling mechanism (backcast branch and the residual link), which should be fed to the next decouple layer. + Shape: [batch_size, seq_len, num_nodes, hidden_dim]. + torch.Tensor: the output of the forecast branch, which will be used to make final prediction. + Shape: [batch_size, seq_len'', num_nodes, forecast_hidden_dim]. seq_len'' = future_len / gap. 
+ In order to reduce the error accumulation in the AR forecasting strategy, we let each hidden state generate the prediction of gap points, instead of a single point. + """ + + [batch_size, seq_len, num_nodes, num_feat] = hidden_inherent_signal.shape + # inherent model + ## rnn + hidden_states_rnn = self.rnn_layer(hidden_inherent_signal) + ## pe + hidden_states_rnn = self.pos_encoder(hidden_states_rnn) + ## MSA + hidden_states_inh = self.transformer_layer(hidden_states_rnn, hidden_states_rnn, hidden_states_rnn) + + # forecast branch + forecast_hidden = self.forecast_block(hidden_inherent_signal, hidden_states_rnn, hidden_states_inh, self.transformer_layer, self.rnn_layer, self.pos_encoder) + + # backcast branch + hidden_states_inh = hidden_states_inh.reshape(seq_len, batch_size, num_nodes, num_feat) + hidden_states_inh = hidden_states_inh.transpose(0, 1) + backcast_seq = self.backcast_fc(hidden_states_inh) + backcast_seq_res = self.residual_decompose(hidden_inherent_signal, backcast_seq) + + return backcast_seq_res, forecast_hidden diff --git a/model/D2STGNN/inherent_block/inh_model.py b/model/D2STGNN/inherent_block/inh_model.py new file mode 100644 index 0000000..a543d1c --- /dev/null +++ b/model/D2STGNN/inherent_block/inh_model.py @@ -0,0 +1,35 @@ +import torch as th +import torch.nn as nn +from torch.nn import MultiheadAttention + + +class RNNLayer(nn.Module): + def __init__(self, hidden_dim, dropout=None): + super().__init__() + self.hidden_dim = hidden_dim + self.gru_cell = nn.GRUCell(hidden_dim, hidden_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, X): + [batch_size, seq_len, num_nodes, hidden_dim] = X.shape + X = X.transpose(1, 2).reshape(batch_size * num_nodes, seq_len, hidden_dim) + hx = th.zeros_like(X[:, 0, :]) + output = [] + for _ in range(X.shape[1]): + hx = self.gru_cell(X[:, _, :], hx) + output.append(hx) + output = th.stack(output, dim=0) + output = self.dropout(output) + return output + + +class TransformerLayer(nn.Module): + def __init__(self, hidden_dim, num_heads=4, dropout=None, bias=True): + super().__init__() + self.multi_head_self_attention = MultiheadAttention(hidden_dim, num_heads, dropout=dropout, bias=bias) + self.dropout = nn.Dropout(dropout) + + def forward(self, X, K, V): + hidden_states_MSA = self.multi_head_self_attention(X, K, V)[0] + hidden_states_MSA = self.dropout(hidden_states_MSA) + return hidden_states_MSA diff --git a/model/D2STGNN/model_config.json b/model/D2STGNN/model_config.json new file mode 100644 index 0000000..635e3f6 --- /dev/null +++ b/model/D2STGNN/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "D2STGNN", + "module": "model.D2STGNN.D2STGNN", + "entry": "D2STGNN" + } +] \ No newline at end of file diff --git a/train.py b/train.py index db9d8dd..15b3b30 100644 --- a/train.py +++ b/train.py @@ -6,14 +6,15 @@ import utils.initializer as init from dataloader.loader_selector import get_dataloader from trainer.trainer_selector import select_trainer + def read_config(config_path): with open(config_path, "r") as file: config = yaml.safe_load(file) - + # global settings - device = "cpu" # override the device here - seed = 2023 # random seed - epochs = 1 # number of training epochs + device = "cpu" # override the device here + seed = 2023 # random seed + epochs = 1 # number of training epochs # copy into config config["basic"]["device"] = device @@ -23,6 +24,7 @@ def read_config(config_path): config["train"]["epochs"] = epochs return config + def run(config): init.init_seed(config["basic"]["seed"]) model = init.init_model(config) @@ -34,10 +36,15 @@ def run(config): init.create_logs(config) trainer = select_trainer( model, - loss,
optimizer, - train_loader, val_loader, test_loader, scaler, + loss, + optimizer, + train_loader, + val_loader, + test_loader, + scaler, config, - lr_scheduler, extra_data, + lr_scheduler, + extra_data, ) # start training @@ -54,17 +61,20 @@ def run(config): ) trainer.test( model.to(config["basic"]["device"]), - trainer.args, test_loader, scaler, + trainer.args, + test_loader, + scaler, trainer.logger, ) case _: raise ValueError(f"Unsupported mode: {config['basic']['mode']}") - -def main(model, data, debug=False): + + +def main(model_list, data, debug=False): # my debug switch; set it to str(False) when not testing # os.environ["TRY"] = str(False) os.environ["TRY"] = str(debug) - + for model in model_list: for dataset in data: config_path = f"./config/{model}/{dataset}.yaml" @@ -77,22 +87,25 @@ def main(model, data, debug=False): except Exception as e: import traceback import sys, traceback + tb_lines = traceback.format_exc().splitlines() # only print the full traceback if it is not an AssertionError if not tb_lines[-1].startswith("AssertionError"): traceback.print_exc() - print(f"\n===== {model} on {dataset} failed with error: {e} =====\n") + print( + f"\n===== {model} on {dataset} failed with error: {e} =====\n" + ) else: run(config) - if __name__ == "__main__": # for debugging # model_list = ["iTransformer", "PatchTST", "HI"] - model_list = ["STNorm"] + model_list = ["D2STGNN"] # model_list = ["PatchTST"] # dataset_list = ["AirQuality"] - dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"] + # dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"] # dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"] + dataset_list = ["BJTaxi-OutFlow"] + main(model_list, dataset_list, debug=True)
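
Reviewer note, not part of the diff: a minimal shape smoke test for the new model, assuming it is run from the repository root. It builds D2STGNN from the model section of the METR-LA config added above and feeds a random batch whose last two channels are the time-of-day fraction and the day-of-week index, matching what _prepare_inputs slices off; the batch size of 2 is arbitrary.

import torch
import yaml

from model.D2STGNN.D2STGNN import D2STGNN

# load the model hyperparameters from the new config
with open("./config/D2STGNN/METR-LA.yaml") as f:
    args = yaml.safe_load(f)["model"]

model = D2STGNN(args)
model.eval()

# dummy batch: raw features plus the two time channels in the last positions
B, L, N = 2, args["seq_len"], args["num_nodes"]
feats = torch.randn(B, L, N, args["input_dim"])
t_in_day = torch.rand(B, L, N, 1)                      # fraction of the day in [0, 1)
d_in_week = torch.randint(0, 7, (B, L, N, 1)).float()  # day-of-week index
x = torch.cat([feats, t_in_day, d_in_week], dim=-1)

with torch.no_grad():
    out = model(x)
print(out.shape)  # expected: (B, horizon, num_nodes, output_dim) = (2, 24, 207, 1)

If the forward pass runs and the printed shape matches, the config keys and the tensor layout assumed by the new model code line up; anything else points at the output-reshaping logic in D2STGNN.forward or at the dataloader's feature ordering.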