From 97eb39073abc75f861362ba6e0bef390f105a292 Mon Sep 17 00:00:00 2001
From: czzhangheng
Date: Wed, 23 Apr 2025 23:22:50 +0800
Subject: [PATCH] Add STIDGCN

---
 config/STIDGCN/PEMSD3.yaml |  48 +++++
 config/STIDGCN/PEMSD4.yaml |  48 +++++
 config/STIDGCN/PEMSD7.yaml |  48 +++++
 config/STIDGCN/PEMSD8.yaml |  48 +++++
 model/STIDGCN/STIDGCN.py   | 368 +++++++++++++++++++++++++++++++++++++
 model/model_selector.py    |   9 +-
 6 files changed, 565 insertions(+), 4 deletions(-)
 create mode 100644 config/STIDGCN/PEMSD3.yaml
 create mode 100644 config/STIDGCN/PEMSD4.yaml
 create mode 100644 config/STIDGCN/PEMSD7.yaml
 create mode 100644 config/STIDGCN/PEMSD8.yaml
 create mode 100644 model/STIDGCN/STIDGCN.py

diff --git a/config/STIDGCN/PEMSD3.yaml b/config/STIDGCN/PEMSD3.yaml
new file mode 100644
index 0000000..c3d8ce2
--- /dev/null
+++ b/config/STIDGCN/PEMSD3.yaml
@@ -0,0 +1,48 @@
+data:
+  num_nodes: 358
+  lag: 12
+  horizon: 12
+  val_ratio: 0.2
+  test_ratio: 0.2
+  tod: False
+  normalizer: std
+  column_wise: False
+  default_graph: True
+  add_time_in_day: True
+  add_day_in_week: True
+  steps_per_day: 288
+  days_per_week: 7
+
+model:
+  input_dim: 3
+  output_dim: 1
+  history: 12
+  horizon: 12
+  granularity: 288
+  dropout: 0.1
+  channels: 32
+
+
+train:
+  loss_func: mae
+  seed: 10
+  batch_size: 64
+  epochs: 300
+  lr_init: 0.003
+  weight_decay: 0
+  lr_decay: False
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  early_stop: True
+  early_stop_patience: 15
+  grad_norm: False
+  max_grad_norm: 5
+  real_value: True
+
+test:
+  mae_thresh: null
+  mape_thresh: 0.0
+
+log:
+  log_step: 200
+  plot: False
diff --git a/config/STIDGCN/PEMSD4.yaml b/config/STIDGCN/PEMSD4.yaml
new file mode 100644
index 0000000..fb4dae6
--- /dev/null
+++ b/config/STIDGCN/PEMSD4.yaml
@@ -0,0 +1,48 @@
+data:
+  num_nodes: 307
+  lag: 12
+  horizon: 12
+  val_ratio: 0.2
+  test_ratio: 0.2
+  tod: False
+  normalizer: std
+  column_wise: False
+  default_graph: True
+  add_time_in_day: True
+  add_day_in_week: True
+  steps_per_day: 288
+  days_per_week: 7
+
+model:
+  input_dim: 3
+  output_dim: 1
+  history: 12
+  horizon: 12
+  granularity: 288
+  dropout: 0.1
+  channels: 64
+
+
+train:
+  loss_func: mae
+  seed: 10
+  batch_size: 64
+  epochs: 300
+  lr_init: 0.003
+  weight_decay: 0
+  lr_decay: False
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  early_stop: True
+  early_stop_patience: 15
+  grad_norm: False
+  max_grad_norm: 5
+  real_value: True
+
+test:
+  mae_thresh: null
+  mape_thresh: 0.0
+
+log:
+  log_step: 200
+  plot: False
diff --git a/config/STIDGCN/PEMSD7.yaml b/config/STIDGCN/PEMSD7.yaml
new file mode 100644
index 0000000..4f67ce1
--- /dev/null
+++ b/config/STIDGCN/PEMSD7.yaml
@@ -0,0 +1,48 @@
+data:
+  num_nodes: 883
+  lag: 12
+  horizon: 12
+  val_ratio: 0.2
+  test_ratio: 0.2
+  tod: False
+  normalizer: std
+  column_wise: False
+  default_graph: True
+  add_time_in_day: True
+  add_day_in_week: True
+  steps_per_day: 288
+  days_per_week: 7
+
+model:
+  input_dim: 3
+  output_dim: 1
+  history: 12
+  horizon: 12
+  granularity: 288
+  dropout: 0.1
+  channels: 128
+
+
+train:
+  loss_func: mae
+  seed: 10
+  batch_size: 16
+  epochs: 300
+  lr_init: 0.003
+  weight_decay: 0
+  lr_decay: False
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  early_stop: True
+  early_stop_patience: 15
+  grad_norm: False
+  max_grad_norm: 5
+  real_value: True
+
+test:
+  mae_thresh: null
+  mape_thresh: 0.0
+
+log:
+  log_step: 200
+  plot: False
diff --git a/config/STIDGCN/PEMSD8.yaml b/config/STIDGCN/PEMSD8.yaml
new file mode 100644
index 0000000..354c9e3
--- /dev/null
+++ b/config/STIDGCN/PEMSD8.yaml
@@ -0,0 +1,48 @@
+data:
+  num_nodes: 170
+  lag: 12
+  horizon: 12
+  val_ratio: 0.2
+  test_ratio: 0.2
+  tod: False
+  normalizer: std
+  column_wise: False
+  default_graph: True
+  add_time_in_day: True
+  add_day_in_week: True
+  steps_per_day: 288
+  days_per_week: 7
+
+model:
+  input_dim: 3
+  output_dim: 1
+  history: 12
+  horizon: 12
+  granularity: 288
+  dropout: 0.1
+  channels: 96
+
+
+train:
+  loss_func: mae
+  seed: 10
+  batch_size: 64
+  epochs: 300
+  lr_init: 0.003
+  weight_decay: 0
+  lr_decay: False
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  early_stop: True
+  early_stop_patience: 15
+  grad_norm: False
+  max_grad_norm: 5
+  real_value: True
+
+test:
+  mae_thresh: null
+  mape_thresh: 0.0
+
+log:
+  log_step: 200
+  plot: False
diff --git a/model/STIDGCN/STIDGCN.py b/model/STIDGCN/STIDGCN.py
new file mode 100644
index 0000000..8c89fb9
--- /dev/null
+++ b/model/STIDGCN/STIDGCN.py
@@ -0,0 +1,368 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+
+class GLU(nn.Module):
+    def __init__(self, features, dropout=0.1):
+        super(GLU, self).__init__()
+        self.conv1 = nn.Conv2d(features, features, (1, 1))
+        self.conv2 = nn.Conv2d(features, features, (1, 1))
+        self.conv3 = nn.Conv2d(features, features, (1, 1))
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        # Gated linear unit over channels: conv1(x) gated by sigmoid(conv2(x)).
+        x1 = self.conv1(x)
+        x2 = self.conv2(x)
+        out = x1 * torch.sigmoid(x2)
+        out = self.dropout(out)
+        out = self.conv3(out)
+        return out
+
+
+class TemporalEmbedding(nn.Module):
+    def __init__(self, time, features):
+        super(TemporalEmbedding, self).__init__()
+
+        self.time = time
+        # Learnable lookup tables: one row per time-of-day step and one per weekday.
+        self.time_day = nn.Parameter(torch.empty(time, features))
+        nn.init.xavier_uniform_(self.time_day)
+
+        self.time_week = nn.Parameter(torch.empty(7, features))
+        nn.init.xavier_uniform_(self.time_week)
+
+    def forward(self, x):
+        # x: (B, T, N, C); channel 1 is time-of-day scaled to [0, 1),
+        # channel 2 is a day-of-week index in {0, ..., 6}.
+        day_emb = x[..., 1]
+        # .long() keeps indices on the input's device; .type(torch.LongTensor)
+        # would force them to CPU and break CUDA runs. The clamp guards against
+        # a time-of-day value of exactly 1.0 indexing out of range.
+        day_idx = (day_emb * self.time).long().clamp_(max=self.time - 1)
+        time_day = self.time_day[day_idx]
+        time_day = time_day.transpose(1, 2).contiguous()
+
+        week_emb = x[..., 2]
+        time_week = self.time_week[week_emb.long()]
+        time_week = time_week.transpose(1, 2).contiguous()
+
+        tem_emb = time_day + time_week
+        tem_emb = tem_emb.permute(0, 3, 1, 2)  # (B, features, N, T)
+        return tem_emb
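+
+# Worked example (illustrative, not from the original source): with
+# granularity=288 and channels=32, a reading taken at 08:05 is step 97 of the
+# day, so x[..., 1] holds 97/288 ≈ 0.3368 and int(0.3368 * 288) = 97 selects
+# row 97 of time_day; adding the matching time_week row and permuting yields a
+# (B, 32, N, 12) embedding for a (B, 12, N, 3) input.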
+
+
+class Diffusion_GCN(nn.Module):
+    def __init__(self, channels=128, diffusion_step=1, dropout=0.1):
+        super().__init__()
+        self.diffusion_step = diffusion_step
+        self.conv = nn.Conv2d(diffusion_step * channels, channels, (1, 1))
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x, adj):
+        # Propagate features over the graph for `diffusion_step` hops and fuse
+        # the concatenated per-hop results with a 1x1 conv.
+        out = []
+        for i in range(0, self.diffusion_step):
+            if adj.dim() == 3:
+                x = torch.einsum("bcnt,bnm->bcmt", x, adj).contiguous()
+                out.append(x)
+            elif adj.dim() == 2:
+                x = torch.einsum("bcnt,nm->bcmt", x, adj).contiguous()
+                out.append(x)
+        x = torch.cat(out, dim=1)
+        x = self.conv(x)
+        output = self.dropout(x)
+        return output
+
+
+class Graph_Generator(nn.Module):
+    def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1):
+        # diffusion_step and dropout are accepted for interface symmetry with
+        # Diffusion_GCN but are not used here.
+        super().__init__()
+        self.memory = nn.Parameter(torch.randn(channels, num_nodes))
+        nn.init.xavier_uniform_(self.memory)
+        self.fc = nn.Linear(2, 1)
+
+    def forward(self, x):
+        # Two candidate adjacencies: feature-vs-memory attention, and
+        # self-attention over the time-pooled node features.
+        adj_dyn_1 = torch.softmax(
+            F.relu(
+                torch.einsum("bcnt,cm->bnm", x, self.memory).contiguous()
+                / math.sqrt(x.shape[1])
+            ),
+            -1,
+        )
+        adj_dyn_2 = torch.softmax(
+            F.relu(
+                torch.einsum("bcn,bcm->bnm", x.sum(-1), x.sum(-1)).contiguous()
+                / math.sqrt(x.shape[1])
+            ),
+            -1,
+        )
+        adj_f = torch.cat([adj_dyn_1.unsqueeze(-1), adj_dyn_2.unsqueeze(-1)], dim=-1)
+        # squeeze(-1) rather than a bare squeeze(): squeeze() would also drop
+        # the batch dimension when batch_size == 1.
+        adj_f = torch.softmax(self.fc(adj_f).squeeze(-1), -1)
+
+        # Sparsify: keep the top 80% of edges per row, zeroing the weakest links.
+        topk_values, topk_indices = torch.topk(adj_f, k=int(adj_f.shape[1] * 0.8), dim=-1)
+        mask = torch.zeros_like(adj_f)
+        mask.scatter_(-1, topk_indices, 1)
+        adj_f = adj_f * mask
+
+        return adj_f
+
+
+class DGCN(nn.Module):
+    def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None):
+        super().__init__()
+        self.conv = nn.Conv2d(channels, channels, (1, 1))
+        self.generator = Graph_Generator(channels, num_nodes, diffusion_step, dropout)
+        self.gcn = Diffusion_GCN(channels, diffusion_step, dropout)
+        self.emb = emb
+
+    def forward(self, x):
+        skip = x
+        x = self.conv(x)
+        adj_dyn = self.generator(x)
+        x = self.gcn(x, adj_dyn)
+        x = x * self.emb + skip
+        return x
+
+
+class Splitting(nn.Module):
+    def __init__(self):
+        super(Splitting, self).__init__()
+
+    def even(self, x):
+        return x[:, :, :, ::2]
+
+    def odd(self, x):
+        return x[:, :, :, 1::2]
+
+    def forward(self, x):
+        return (self.even(x), self.odd(x))
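+
+# Example of the split (illustrative): for T=12, even() keeps time steps
+# {0, 2, ..., 10} and odd() keeps {1, 3, ..., 11}, each of shape (B, C, N, 6);
+# IDGCN_Tree.concat() below is the inverse interleaving.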
+
+
+class IDGCN(nn.Module):
+    def __init__(
+        self,
+        device,
+        channels=64,
+        diffusion_step=1,
+        splitting=True,
+        num_nodes=170,
+        dropout=0.2,
+        emb=None,
+    ):
+        super(IDGCN, self).__init__()
+
+        self.dropout = dropout
+        self.num_nodes = num_nodes
+        self.splitting = splitting
+        self.split = Splitting()
+
+        pad_l = 3
+        pad_r = 3
+        k1 = 5
+        k2 = 3
+
+        def temporal_conv():
+            # pad(3+3) -> conv(1,5) -> LeakyReLU -> dropout -> conv(1,3) -> tanh.
+            # The padding exactly offsets the 4 + 2 steps the two convs consume,
+            # so the temporal length is preserved. The four branches are
+            # structurally identical but learn separate weights.
+            return nn.Sequential(
+                nn.ReplicationPad2d((pad_l, pad_r, 0, 0)),
+                nn.Conv2d(channels, channels, kernel_size=(1, k1)),
+                nn.LeakyReLU(negative_slope=0.01, inplace=True),
+                nn.Dropout(self.dropout),
+                nn.Conv2d(channels, channels, kernel_size=(1, k2)),
+                nn.Tanh(),
+            )
+
+        self.conv1 = temporal_conv()
+        self.conv2 = temporal_conv()
+        self.conv3 = temporal_conv()
+        self.conv4 = temporal_conv()
+
+        self.dgcn = DGCN(channels, num_nodes, diffusion_step, dropout, emb)
+
+    def forward(self, x):
+        if self.splitting:
+            (x_even, x_odd) = self.split(x)
+        else:
+            (x_even, x_odd) = x
+
+        # Interactive learning: each half is gated by information distilled
+        # from the other half, then refined and added back.
+        x1 = self.conv1(x_even)
+        x1 = self.dgcn(x1)
+        d = x_odd.mul(torch.tanh(x1))
+
+        x2 = self.conv2(x_odd)
+        x2 = self.dgcn(x2)
+        c = x_even.mul(torch.tanh(x2))
+
+        x3 = self.conv3(c)
+        x3 = self.dgcn(x3)
+        x_odd_update = d + x3
+
+        x4 = self.conv4(d)
+        x4 = self.dgcn(x4)
+        x_even_update = c + x4
+
+        return (x_even_update, x_odd_update)
+
+
+class IDGCN_Tree(nn.Module):
+    def __init__(
+        self, device, channels=64, diffusion_step=1, num_nodes=170, dropout=0.1
+    ):
+        super().__init__()
+
+        # Per-level positional embeddings: T=6 after the first split,
+        # T=3 after the second; each of the three IDGCN nodes gets its own.
+        self.memory1 = nn.Parameter(torch.randn(channels, num_nodes, 6))
+        self.memory2 = nn.Parameter(torch.randn(channels, num_nodes, 3))
+        self.memory3 = nn.Parameter(torch.randn(channels, num_nodes, 3))
+
+        self.IDGCN1 = IDGCN(
+            device=device,
+            splitting=True,
+            channels=channels,
+            diffusion_step=diffusion_step,
+            num_nodes=num_nodes,
+            dropout=dropout,
+            emb=self.memory1,
+        )
+        self.IDGCN2 = IDGCN(
+            device=device,
+            splitting=True,
+            channels=channels,
+            diffusion_step=diffusion_step,
+            num_nodes=num_nodes,
+            dropout=dropout,
+            emb=self.memory2,
+        )
+        self.IDGCN3 = IDGCN(
+            device=device,
+            splitting=True,
+            channels=channels,
+            diffusion_step=diffusion_step,
+            num_nodes=num_nodes,
+            dropout=dropout,
+            emb=self.memory3,
+        )
+
+    def concat(self, even, odd):
+        # Inverse of Splitting: interleave the even/odd halves back along time.
+        even = even.permute(3, 1, 2, 0)
+        odd = odd.permute(3, 1, 2, 0)
+        pieces = []
+        for i in range(even.shape[0]):
+            pieces.append(even[i].unsqueeze(0))
+            pieces.append(odd[i].unsqueeze(0))
+        return torch.cat(pieces, 0).permute(3, 1, 2, 0)
+
+    def forward(self, x):
+        x_even_update1, x_odd_update1 = self.IDGCN1(x)
+        x_even_update2, x_odd_update2 = self.IDGCN2(x_even_update1)
+        x_even_update3, x_odd_update3 = self.IDGCN3(x_odd_update1)
+        concat1 = self.concat(x_even_update2, x_odd_update2)
+        concat2 = self.concat(x_even_update3, x_odd_update3)
+        concat0 = self.concat(concat1, concat2)
+        output = concat0 + x
+        return output
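+
+# Pipeline summary (a hedged reading of the code below): start_conv lifts the
+# raw (B, 3, N, 12) input to `channels` dims, TemporalEmbedding contributes
+# another `channels` dims, IDGCN_Tree mixes the concatenated
+# (B, 2*channels, N, 12) tensor spatio-temporally, and GLU plus a (1, 12) conv
+# regress all horizon steps at once, giving (B, horizon, N, 1).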
+
+
+class STIDGCN(nn.Module):
+    def __init__(self, args):
+        """
+        Expects args with: device, input_dim, num_nodes, channels,
+        granularity, dropout and horizon (see config/STIDGCN/*.yaml).
+        """
+        super().__init__()
+
+        device = args['device']
+        input_dim = args['input_dim']
+        self.num_nodes = args['num_nodes']
+        self.output_len = args['horizon']
+        channels = args['channels']
+        granularity = args['granularity']
+        dropout = args['dropout']
+        diffusion_step = 1
+
+        self.Temb = TemporalEmbedding(granularity, channels)
+
+        self.start_conv = nn.Conv2d(
+            in_channels=input_dim, out_channels=channels, kernel_size=(1, 1)
+        )
+
+        self.tree = IDGCN_Tree(
+            device=device,
+            channels=channels * 2,
+            diffusion_step=diffusion_step,
+            num_nodes=self.num_nodes,
+            dropout=dropout,
+        )
+
+        self.glu = GLU(channels * 2, dropout)
+
+        self.regression_layer = nn.Conv2d(
+            channels * 2, self.output_len, kernel_size=(1, self.output_len)
+        )
+
+    def param_num(self):
+        return sum(param.nelement() for param in self.parameters())
+
+    def forward(self, input):
+        # input: (B, T, N, C) -> (B, C, N, T)
+        input = input.transpose(1, 3)
+        x = input
+        # Encoder: data embedding (value conv and temporal lookup, concatenated
+        # along the channel dimension)
+        time_emb = self.Temb(input.permute(0, 3, 2, 1))
+        x = torch.cat([self.start_conv(x), time_emb], dim=1)
+        # IDGCN_Tree
+        x = self.tree(x)
+        # Decoder: gated skip connection, then regress all horizon steps at once
+        gcn = self.glu(x) + x
+        prediction = self.regression_layer(F.relu(gcn))
+        return prediction
diff --git a/model/model_selector.py b/model/model_selector.py
index 54b8c5f..9067833 100644
--- a/model/model_selector.py
+++ b/model/model_selector.py
@@ -13,8 +13,8 @@ from model.STFGNN.STFGNN import STFGNN
 from model.STSGCN.STSGCN import STSGCN
 from model.STGODE.STGODE import ODEGCN
 from model.PDG2SEQ.PDG2Seq import PDG2Seq
-from model.EXP.EXP import EXP
-from model.EXPB.EXP_b import EXPB
+from model.STIDGCN.STIDGCN import STIDGCN
+
 
 def model_selector(model):
     match model['type']:
@@ -33,6 +33,7 @@ def model_selector(model):
         case 'STSGCN': return STSGCN(model)
         case 'STGODE': return ODEGCN(model)
         case 'PDG2SEQ': return PDG2Seq(model)
-        case 'EXP': return EXP(model)
-        case 'EXPB': return EXPB(model)
+        case 'STIDGCN': return STIDGCN(model)
+        # case 'EXP': return EXP(model)
+        # case 'EXPB': return EXPB(model)
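
Smoke test (not part of the patch): a minimal sketch of how the new model is
driven, assuming the trainer merges the YAML `model` section with `device` and
`num_nodes` into one dict before calling `model_selector`, and feeds
`(batch, lag, num_nodes, input_dim)` tensors. The values mirror
config/STIDGCN/PEMSD8.yaml; the dict literal itself is hypothetical.

    import torch
    from model.model_selector import model_selector

    args = {
        'type': 'STIDGCN', 'device': 'cpu', 'input_dim': 3, 'num_nodes': 170,
        'channels': 96, 'granularity': 288, 'dropout': 0.1, 'horizon': 12,
    }
    model = model_selector(args)
    x = torch.rand(2, 12, 170, 3)                           # (batch, lag, nodes, features)
    x[..., 1] = torch.rand(2, 12, 170) * (287 / 288)        # time-of-day in [0, 1)
    x[..., 2] = torch.randint(0, 7, (2, 12, 170)).float()   # day-of-week index
    y = model(x)
    assert y.shape == (2, 12, 170, 1)                       # (batch, horizon, nodes, 1)
    print(model.param_num())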