From bc9a2667c232d7768299dfe7686d2075a3ae77e2 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 7 Apr 2025 17:05:59 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E4=BA=86=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E8=92=B8=E9=A6=8FSTMLP=20=E7=8E=B0=E5=9C=A8Trainer?= =?UTF-8?q?=E6=AF=8F=E6=AC=A1epoch=E5=AE=8C=E5=90=8E=E9=83=BD=E4=BC=9A?= =?UTF-8?q?=E4=BF=9D=E5=AD=98=E6=A8=A1=E5=9E=8Bcheckpoint=20=E5=85=B6?= =?UTF-8?q?=E4=B8=ADSTMLP=E4=BC=9A=E8=87=AA=E5=8A=A8=E6=95=99=E5=B8=88?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=88=B0pre-train=20=E6=A0=B9=E6=8D=AE?= =?UTF-8?q?=E6=95=99=E5=B8=88=E6=A8=A1=E5=9E=8B=E7=9A=84=E5=AD=98=E5=9C=A8?= =?UTF-8?q?=E6=83=85=E5=86=B5=E5=90=AF=E5=8A=A8/=E9=A2=84=E8=AE=AD?= =?UTF-8?q?=E7=BB=83or=E8=92=B8=E9=A6=8F=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + config/STMLP/PEMSD3.yaml | 66 ++++++++ config/STMLP/PEMSD4.yaml | 67 ++++++++ config/STMLP/PEMSD7.yaml | 66 ++++++++ config/STMLP/PEMSD8.yaml | 66 ++++++++ lib/Download_data.py | 2 +- model/STMLP/STMLP.py | 307 ++++++++++++++++++++++++++++++++++++ model/model_selector.py | 6 +- run.py | 3 - trainer/DCRNN_Trainer.py | 4 - trainer/PDG2SEQ_Trainer.py | 4 - trainer/STMLP_Trainer.py | 261 ++++++++++++++++++++++++++++++ trainer/Trainer.py | 4 +- trainer/trainer_selector.py | 3 + transfer_guide.md | 2 +- 15 files changed, 844 insertions(+), 18 deletions(-) create mode 100644 config/STMLP/PEMSD3.yaml create mode 100644 config/STMLP/PEMSD4.yaml create mode 100644 config/STMLP/PEMSD7.yaml create mode 100644 config/STMLP/PEMSD8.yaml create mode 100644 model/STMLP/STMLP.py create mode 100644 trainer/STMLP_Trainer.py diff --git a/.gitignore b/.gitignore index a68924d..d67c4d1 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ experiments/ *.pkl data/ pretrain/ +pre-train/ # ---> Python # Byte-compiled / optimized / DLL files diff --git a/config/STMLP/PEMSD3.yaml b/config/STMLP/PEMSD3.yaml new file mode 100644 index 0000000..eee7a15 --- /dev/null +++ b/config/STMLP/PEMSD3.yaml @@ -0,0 +1,66 @@ +data: + num_nodes: 358 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + input_window: 12 + output_window: 12 + gcn_true: true + buildA_true: true + gcn_depth: 2 + dropout: 0.3 + subgraph_size: 20 + node_dim: 40 + dilation_exponential: 1 + conv_channels: 32 + residual_channels: 32 + skip_channels: 64 + end_channels: 128 + layers: 3 + propalpha: 0.05 + tanhalpha: 3 + layer_norm_affline: true + use_curriculum_learning: true + step_size1: 2500 + task_level: 0 + num_split: 1 + step_size2: 100 + model_type: stmlp + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + teacher_stu: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 2000 + plot: False diff --git a/config/STMLP/PEMSD4.yaml b/config/STMLP/PEMSD4.yaml new file mode 100644 index 0000000..c416fc4 --- /dev/null +++ b/config/STMLP/PEMSD4.yaml @@ -0,0 +1,67 @@ +data: + num_nodes: 307 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: 
True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + input_window: 12 + output_window: 12 + gcn_true: true + buildA_true: true + gcn_depth: 2 + dropout: 0.3 + subgraph_size: 20 + node_dim: 40 + dilation_exponential: 1 + conv_channels: 32 + residual_channels: 32 + skip_channels: 64 + end_channels: 128 + layers: 3 + propalpha: 0.05 + tanhalpha: 3 + layer_norm_affline: true + use_curriculum_learning: true + step_size1: 2500 + task_level: 0 + num_split: 1 + step_size2: 100 + model_type: stmlp + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + teacher: True + teacher_stu: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 2000 + plot: False diff --git a/config/STMLP/PEMSD7.yaml b/config/STMLP/PEMSD7.yaml new file mode 100644 index 0000000..14e6382 --- /dev/null +++ b/config/STMLP/PEMSD7.yaml @@ -0,0 +1,66 @@ +data: + num_nodes: 883 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + input_window: 12 + output_window: 12 + gcn_true: true + buildA_true: true + gcn_depth: 2 + dropout: 0.3 + subgraph_size: 20 + node_dim: 40 + dilation_exponential: 1 + conv_channels: 32 + residual_channels: 32 + skip_channels: 64 + end_channels: 128 + layers: 3 + propalpha: 0.05 + tanhalpha: 3 + layer_norm_affline: true + use_curriculum_learning: true + step_size1: 2500 + task_level: 0 + num_split: 1 + step_size2: 100 + model_type: stmlp + +train: + loss_func: mae + seed: 10 + batch_size: 16 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + teacher_stu: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 2000 + plot: False diff --git a/config/STMLP/PEMSD8.yaml b/config/STMLP/PEMSD8.yaml new file mode 100644 index 0000000..bceffa5 --- /dev/null +++ b/config/STMLP/PEMSD8.yaml @@ -0,0 +1,66 @@ +data: + num_nodes: 170 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + input_window: 12 + output_window: 12 + gcn_true: true + buildA_true: true + gcn_depth: 2 + dropout: 0.3 + subgraph_size: 20 + node_dim: 40 + dilation_exponential: 1 + conv_channels: 32 + residual_channels: 32 + skip_channels: 64 + end_channels: 128 + layers: 3 + propalpha: 0.05 + tanhalpha: 3 + layer_norm_affline: true + use_curriculum_learning: true + step_size1: 2500 + task_level: 0 + num_split: 1 + step_size2: 100 + model_type: stmlp + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + teacher_stu: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 2000 + plot: False diff --git a/lib/Download_data.py b/lib/Download_data.py index 
9cc0006..ed7c929 100644 --- a/lib/Download_data.py +++ b/lib/Download_data.py @@ -121,7 +121,7 @@ def download_kaggle_data(current_dir): 如果目标文件夹已存在,会覆盖冲突的文件。 """ try: - print("正在下载 KaggleHub 数据集...") + print("正在下载 PEMS 数据集...") path = kagglehub.dataset_download("elmahy/pems-dataset") # print("Path to KaggleHub dataset files:", path) diff --git a/model/STMLP/STMLP.py b/model/STMLP/STMLP.py new file mode 100644 index 0000000..8af1134 --- /dev/null +++ b/model/STMLP/STMLP.py @@ -0,0 +1,307 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +from data.get_adj import get_adj +import numbers + + +# --- 基础算子 --- +class NConv(nn.Module): + def forward(self, x, adj): + return torch.einsum('ncwl,vw->ncvl', (x, adj)).contiguous() + + +class DyNconv(nn.Module): + def forward(self, x, adj): + return torch.einsum('ncvl,nvwl->ncwl', (x, adj)).contiguous() + + +class Linear(nn.Module): + def __init__(self, c_in, c_out, bias=True): + super().__init__() + self.mlp = nn.Conv2d(c_in, c_out, kernel_size=1, bias=bias) + + def forward(self, x): + return self.mlp(x) + + +class Prop(nn.Module): + def __init__(self, c_in, c_out, gdep, dropout, alpha): + super().__init__() + self.nconv = NConv() + self.mlp = Linear(c_in, c_out) + self.gdep, self.dropout, self.alpha = gdep, dropout, alpha + + def forward(self, x, adj): + adj = adj + torch.eye(adj.size(0), device=x.device) + d = adj.sum(1) + a = adj / d.view(-1, 1) + h = x + for _ in range(self.gdep): + h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a) + return self.mlp(h) + + +class MixProp(nn.Module): + def __init__(self, c_in, c_out, gdep, dropout, alpha): + super().__init__() + self.nconv = NConv() + self.mlp = Linear((gdep + 1) * c_in, c_out) + self.gdep, self.dropout, self.alpha = gdep, dropout, alpha + + def forward(self, x, adj): + adj = adj + torch.eye(adj.size(0), device=x.device) + d = adj.sum(1) + a = adj / d.view(-1, 1) + out = [x] + h = x + for _ in range(self.gdep): + h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a) + out.append(h) + return self.mlp(torch.cat(out, dim=1)) + + +class DyMixprop(nn.Module): + def __init__(self, c_in, c_out, gdep, dropout, alpha): + super().__init__() + self.nconv = DyNconv() + self.mlp1 = Linear((gdep + 1) * c_in, c_out) + self.mlp2 = Linear((gdep + 1) * c_in, c_out) + self.gdep, self.dropout, self.alpha = gdep, dropout, alpha + self.lin1, self.lin2 = Linear(c_in, c_in), Linear(c_in, c_in) + + def forward(self, x): + x1 = torch.tanh(self.lin1(x)) + x2 = torch.tanh(self.lin2(x)) + adj = self.nconv(x1.transpose(2, 1), x2) + adj0 = torch.softmax(adj, dim=2) + adj1 = torch.softmax(adj.transpose(2, 1), dim=2) + # 两条分支 + out1, out2 = [x], [x] + h = x + for _ in range(self.gdep): + h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj0) + out1.append(h) + h = x + for _ in range(self.gdep): + h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj1) + out2.append(h) + return self.mlp1(torch.cat(out1, dim=1)) + self.mlp2(torch.cat(out2, dim=1)) + + +class DilatedInception(nn.Module): + def __init__(self, cin, cout, dilation_factor=2): + super().__init__() + self.kernels = [2, 3, 6, 7] + cout_each = int(cout / len(self.kernels)) + self.convs = nn.ModuleList([nn.Conv2d(cin, cout_each, kernel_size=(1, k), dilation=(1, dilation_factor)) + for k in self.kernels]) + + def forward(self, x): + outs = [conv(x)[..., -self.convs[-1](x).size(3):] for conv in self.convs] + return torch.cat(outs, dim=1) + + +class GraphConstructor(nn.Module): + def __init__(self, nnodes, k, 
dim, device, alpha=3, static_feat=None): + super().__init__() + self.nnodes, self.k, self.dim, self.alpha, self.device = nnodes, k, dim, alpha, device + self.static_feat = static_feat + if static_feat is not None: + xd = static_feat.shape[1] + self.lin1, self.lin2 = nn.Linear(xd, dim), nn.Linear(xd, dim) + else: + self.emb1 = nn.Embedding(nnodes, dim) + self.emb2 = nn.Embedding(nnodes, dim) + self.lin1, self.lin2 = nn.Linear(dim, dim), nn.Linear(dim, dim) + + def forward(self, idx): + if self.static_feat is None: + vec1, vec2 = self.emb1(idx), self.emb2(idx) + else: + vec1 = vec2 = self.static_feat[idx, :] + vec1 = torch.tanh(self.alpha * self.lin1(vec1)) + vec2 = torch.tanh(self.alpha * self.lin2(vec2)) + a = torch.mm(vec1, vec2.transpose(1, 0)) - torch.mm(vec2, vec1.transpose(1, 0)) + adj = F.relu(torch.tanh(self.alpha * a)) + mask = torch.zeros(idx.size(0), idx.size(0), device=self.device) + s1, t1 = adj.topk(self.k, 1) + mask.scatter_(1, t1, s1.new_ones(s1.size())) + return adj * mask + + +class LayerNorm(nn.Module): + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] + + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super().__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape, self.eps, self.elementwise_affine = tuple(normalized_shape), eps, elementwise_affine + if elementwise_affine: + self.weight = nn.Parameter(torch.Tensor(*normalized_shape)) + self.bias = nn.Parameter(torch.Tensor(*normalized_shape)) + init.ones_(self.weight); + init.zeros_(self.bias) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + + def forward(self, x, idx): + if self.elementwise_affine: + return F.layer_norm(x, tuple(x.shape[1:]), self.weight[:, idx, :], self.bias[:, idx, :], self.eps) + else: + return F.layer_norm(x, tuple(x.shape[1:]), self.weight, self.bias, self.eps) + + def extra_repr(self): + return f'{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}' + + +# --- 合并后的模型类,支持 teacher 与 stmlp 两种分支 --- +class STMLP(nn.Module): + def __init__(self, args): + super().__init__() + # 参数从字典中读取 + self.adj_mx = get_adj(args) + self.num_nodes = args['num_nodes'] + self.feature_dim = args['input_dim'] + + self.input_window = args['input_window'] + self.output_window = args['output_window'] + self.output_dim = args['output_dim'] + self.device = args['device'] + + self.gcn_true = args['gcn_true'] + self.buildA_true = args['buildA_true'] + self.gcn_depth = args['gcn_depth'] + self.dropout = args['dropout'] + self.subgraph_size = args['subgraph_size'] + self.node_dim = args['node_dim'] + self.dilation_exponential = args['dilation_exponential'] + + self.conv_channels = args['conv_channels'] + self.residual_channels = args['residual_channels'] + self.skip_channels = args['skip_channels'] + self.end_channels = args['end_channels'] + + self.layers = args['layers'] + self.propalpha = args['propalpha'] + self.tanhalpha = args['tanhalpha'] + self.layer_norm_affline = args['layer_norm_affline'] + + self.model_type = args['model_type'] # 'teacher' 或 'stmlp' + self.idx = torch.arange(self.num_nodes).to(self.device) + self.predefined_A = None if self.adj_mx is None else (torch.tensor(self.adj_mx) - torch.eye(self.num_nodes)).to( + self.device) + self.static_feat = None + + # transformer(保留原有结构) + self.encoder_layer = nn.TransformerEncoderLayer(d_model=12, nhead=4, batch_first=True) + self.transformer_encoder = 
nn.TransformerEncoder(self.encoder_layer, num_layers=3) + + # 构建各层 + self.start_conv = nn.Conv2d(self.feature_dim, self.residual_channels, kernel_size=1) + self.gc = GraphConstructor(self.num_nodes, self.subgraph_size, self.node_dim, self.device, alpha=self.tanhalpha, + static_feat=self.static_feat) + # 计算 receptive_field + kernel_size = 7 + if self.dilation_exponential > 1: + self.receptive_field = int( + self.output_dim + (kernel_size - 1) * (self.dilation_exponential ** self.layers - 1) / ( + self.dilation_exponential - 1)) + else: + self.receptive_field = self.layers * (kernel_size - 1) + self.output_dim + + self.filter_convs = nn.ModuleList() + self.gate_convs = nn.ModuleList() + self.residual_convs = nn.ModuleList() + self.skip_convs = nn.ModuleList() + self.norm = nn.ModuleList() + self.stu_mlp = nn.ModuleList([nn.Sequential(nn.Linear(c, c), nn.Linear(c, c), nn.Linear(c, c)) + for c in [13, 7, 1]]) + if self.gcn_true: + self.gconv1 = nn.ModuleList() + self.gconv2 = nn.ModuleList() + + new_dilation = 1 + for i in range(1): + rf_size_i = int(1 + i * (kernel_size - 1) * (self.dilation_exponential ** self.layers - 1) / ( + self.dilation_exponential - 1)) if self.dilation_exponential > 1 else i * self.layers * ( + kernel_size - 1) + 1 + for j in range(1, self.layers + 1): + rf_size_j = int(rf_size_i + (kernel_size - 1) * (self.dilation_exponential ** j - 1) / ( + self.dilation_exponential - 1)) if self.dilation_exponential > 1 else rf_size_i + j * ( + kernel_size - 1) + self.filter_convs.append( + DilatedInception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation)) + self.gate_convs.append( + DilatedInception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation)) + self.residual_convs.append(nn.Conv2d(self.conv_channels, self.residual_channels, kernel_size=1)) + k_size = (1, self.input_window - rf_size_j + 1) if self.input_window > self.receptive_field else ( + 1, self.receptive_field - rf_size_j + 1) + self.skip_convs.append(nn.Conv2d(self.conv_channels, self.skip_channels, kernel_size=k_size)) + if self.gcn_true: + self.gconv1.append(MixProp(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout, + self.propalpha)) + self.gconv2.append(MixProp(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout, + self.propalpha)) + norm_size = (self.residual_channels, self.num_nodes, + self.input_window - rf_size_j + 1) if self.input_window > self.receptive_field else ( + self.residual_channels, self.num_nodes, self.receptive_field - rf_size_j + 1) + self.norm.append(LayerNorm(norm_size, elementwise_affine=self.layer_norm_affline)) + new_dilation *= self.dilation_exponential + + self.end_conv_1 = nn.Conv2d(self.skip_channels, self.end_channels, kernel_size=1, bias=True) + self.end_conv_2 = nn.Conv2d(self.end_channels, self.output_window, kernel_size=1, bias=True) + k0 = (1, self.input_window) if self.input_window > self.receptive_field else (1, self.receptive_field) + self.skip0 = nn.Conv2d(self.feature_dim, self.skip_channels, kernel_size=k0, bias=True) + kE = (1, self.input_window - self.receptive_field + 1) if self.input_window > self.receptive_field else (1, 1) + self.skipE = nn.Conv2d(self.residual_channels, self.skip_channels, kernel_size=kE, bias=True) + # 最后输出分支,根据模型类型选择不同的头 + if self.model_type == 'teacher': + self.tt_linear1 = nn.Linear(self.residual_channels, self.input_window) + self.tt_linear2 = nn.Linear(1, 32) + self.ss_linear1 = nn.Linear(self.residual_channels, self.input_window) + self.ss_linear2 = 
nn.Linear(1, 32) + else: # stmlp + self.out_linear1 = nn.Linear(self.residual_channels, self.input_window) + self.out_linear2 = nn.Linear(1, 32) + + def forward(self, source, idx=None): + source = source[..., 0:1] + sout, tout = [], [] + inputs = source.transpose(1, 3) + assert inputs.size(3) == self.input_window, 'input sequence length mismatch' + if self.input_window < self.receptive_field: + inputs = F.pad(inputs, (self.receptive_field - self.input_window, 0, 0, 0)) + if self.gcn_true: + adp = self.gc(self.idx if idx is None else idx) if self.buildA_true else self.predefined_A + x = self.start_conv(inputs) + skip = self.skip0(F.dropout(inputs, self.dropout, training=self.training)) + for i in range(self.layers): + residual = x + filters = torch.tanh(self.filter_convs[i](x)) + gate = torch.sigmoid(self.gate_convs[i](x)) + x = F.dropout(filters * gate, self.dropout, training=self.training) + tout.append(x) + s = self.skip_convs[i](x) + skip = s + skip + if self.gcn_true: + x = self.gconv1[i](x, adp) + self.gconv2[i](x, adp.transpose(1, 0)) + else: + x = self.stu_mlp[i](x) + x = x + residual[:, :, :, -x.size(3):] + x = self.norm[i](x, self.idx if idx is None else idx) + sout.append(x) + skip = self.skipE(x) + skip + x = F.relu(skip) + x = F.relu(self.end_conv_1(x)) + x = self.end_conv_2(x) + if self.model_type == 'teacher': + ttout = self.tt_linear2(self.tt_linear1(tout[-1].transpose(1, 3)).transpose(1, 3)) + ssout = self.ss_linear2(self.ss_linear1(sout[-1].transpose(1, 3)).transpose(1, 3)) + return x, ttout, ssout + else: + x_ = self.out_linear2(self.out_linear1(tout[-1].transpose(1, 3)).transpose(1, 3)) + return x, x_, x diff --git a/model/model_selector.py b/model/model_selector.py index 54b8c5f..0117619 100644 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -13,8 +13,7 @@ from model.STFGNN.STFGNN import STFGNN from model.STSGCN.STSGCN import STSGCN from model.STGODE.STGODE import ODEGCN from model.PDG2SEQ.PDG2Seq import PDG2Seq -from model.EXP.EXP import EXP -from model.EXPB.EXP_b import EXPB +from model.STMLP.STMLP import STMLP def model_selector(model): match model['type']: @@ -33,6 +32,5 @@ def model_selector(model): case 'STSGCN': return STSGCN(model) case 'STGODE': return ODEGCN(model) case 'PDG2SEQ': return PDG2Seq(model) - case 'EXP': return EXP(model) - case 'EXPB': return EXPB(model) + case 'STMLP': return STMLP(model) diff --git a/run.py b/run.py index 7499708..082fe05 100644 --- a/run.py +++ b/run.py @@ -17,9 +17,6 @@ from dataloader.loader_selector import get_dataloader from trainer.trainer_selector import select_trainer import yaml - - - def main(): args = parse_args() diff --git a/trainer/DCRNN_Trainer.py b/trainer/DCRNN_Trainer.py index ecc4eb0..97a8290 100644 --- a/trainer/DCRNN_Trainer.py +++ b/trainer/DCRNN_Trainer.py @@ -160,10 +160,6 @@ class Trainer: y_pred = torch.cat(y_pred, dim=0) y_true = torch.cat(y_true, dim=0) - # 你在这里需要把y_pred和y_true保存下来 - # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt") # [3566,12,170,1] - # torch.save(y_true, "./test/PEMS08/y_true.pt") # [3566,12,170,1] - for t in range(y_true.shape[1]): mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], args['mae_thresh'], args['mape_thresh']) diff --git a/trainer/PDG2SEQ_Trainer.py b/trainer/PDG2SEQ_Trainer.py index 00750a1..bde4801 100644 --- a/trainer/PDG2SEQ_Trainer.py +++ b/trainer/PDG2SEQ_Trainer.py @@ -161,10 +161,6 @@ class Trainer: y_pred = torch.cat(y_pred, dim=0) y_true = torch.cat(y_true, dim=0) - # 你在这里需要把y_pred和y_true保存下来 - # torch.save(y_pred, 
"./test/PEMS07/y_pred_D.pt") # [3566,12,170,1] - # torch.save(y_true, "./test/PEMS08/y_true.pt") # [3566,12,170,1] - for t in range(y_true.shape[1]): mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], args['mae_thresh'], args['mape_thresh']) diff --git a/trainer/STMLP_Trainer.py b/trainer/STMLP_Trainer.py new file mode 100644 index 0000000..6489221 --- /dev/null +++ b/trainer/STMLP_Trainer.py @@ -0,0 +1,261 @@ +import math +import os +import sys +import time +import copy +import torch.nn.functional as F +import torch +from torch import nn + +from tqdm import tqdm +from lib.logger import get_logger +from lib.loss_function import all_metrics +from model.STMLP.STMLP import STMLP + + +class Trainer: + def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, + scaler, args, lr_scheduler=None): + self.model = model + self.loss = loss + self.optimizer = optimizer + self.train_loader = train_loader + self.val_loader = val_loader + self.test_loader = test_loader + self.scaler = scaler + self.args = args['train'] + self.lr_scheduler = lr_scheduler + self.train_per_epoch = len(train_loader) + self.val_per_epoch = len(val_loader) if val_loader else 0 + + # Paths for saving models and logs + self.best_path = os.path.join(self.args['log_dir'], 'best_model.pth') + self.best_test_path = os.path.join(self.args['log_dir'], 'best_test_model.pth') + self.loss_figure_path = os.path.join(self.args['log_dir'], 'loss.png') + self.pretrain_dir = f'./pre-train/{args["model"]["type"]}/{args["data"]["type"]}' + self.pretrain_path = os.path.join(self.pretrain_dir, 'best_model.pth') + self.pretrain_best_path = os.path.join(self.pretrain_dir, 'best_test_model.pth') + + # Initialize logger + if not os.path.isdir(self.args['log_dir']) and not self.args['debug']: + os.makedirs(self.args['log_dir'], exist_ok=True) + if not os.path.isdir(self.pretrain_dir) and not self.args['debug']: + os.makedirs(self.pretrain_dir, exist_ok=True) + self.logger = get_logger(self.args['log_dir'], name=self.model.__class__.__name__, debug=self.args['debug']) + self.logger.info(f"Experiment log path in: {self.args['log_dir']}") + + if self.args['teacher_stu']: + self.tmodel = self.loadTeacher(args) + else: + self.logger.info(f"当前使用预训练模式,预训练后请移动教师模型到" + f"./pre-train/{args['model']['type']}/{args['data']['type']}/best_model.pth" + f"然后在config中配置train.teacher_stu模式为True开启蒸馏模式") + + + def _run_epoch(self, epoch, dataloader, mode): + # self.tmodel.eval() + if mode == 'train': + self.model.train() + optimizer_step = True + else: + self.model.eval() + optimizer_step = False + + total_loss = 0 + epoch_time = time.time() + + with torch.set_grad_enabled(optimizer_step): + with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: + for batch_idx, (data, target) in enumerate(dataloader): + if self.args['teacher_stu']: + label = target[..., :self.args['output_dim']] + output, out_, _ = self.model(data) + gout, tout, sout = self.tmodel(data) + + if self.args['real_value']: + output = self.scaler.inverse_transform(output) + + loss1 = self.loss(output, label) + scl = self.loss_cls(out_, sout) + kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True).cuda() + gout = F.log_softmax(gout, dim=-1).cuda() + mlp_emb_ = F.log_softmax(output, dim=-1).cuda() + tkloss = kl_loss(mlp_emb_.cuda().float(), gout.cuda().float()) + loss = loss1 + 10 * tkloss + 1 * scl + + else: + label = target[..., :self.args['output_dim']] + output, out_, _ = self.model(data) + + if self.args['real_value']: + output 
= self.scaler.inverse_transform(output) + + loss = self.loss(output, label) + + if optimizer_step and self.optimizer is not None: + self.optimizer.zero_grad() + loss.backward() + + if self.args['grad_norm']: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm']) + self.optimizer.step() + + total_loss += loss.item() + + if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0: + self.logger.info( + f'Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}') + + # 更新 tqdm 的进度 + pbar.update(1) + pbar.set_postfix(loss=loss.item()) + + avg_loss = total_loss / len(dataloader) + self.logger.info( + f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') + return avg_loss + + def train_epoch(self, epoch): + return self._run_epoch(epoch, self.train_loader, 'train') + + def val_epoch(self, epoch): + return self._run_epoch(epoch, self.val_loader or self.test_loader, 'val') + + def test_epoch(self, epoch): + return self._run_epoch(epoch, self.test_loader, 'test') + + def train(self): + best_model, best_test_model = None, None + best_loss, best_test_loss = float('inf'), float('inf') + not_improved_count = 0 + + self.logger.info("Training process started") + for epoch in range(1, self.args['epochs'] + 1): + train_epoch_loss = self.train_epoch(epoch) + val_epoch_loss = self.val_epoch(epoch) + test_epoch_loss = self.test_epoch(epoch) + + if train_epoch_loss > 1e6: + self.logger.warning('Gradient explosion detected. Ending...') + break + + if val_epoch_loss < best_loss: + best_loss = val_epoch_loss + not_improved_count = 0 + best_model = copy.deepcopy(self.model.state_dict()) + torch.save(best_model, self.best_path) + torch.save(best_model, self.pretrain_path) + self.logger.info('Best validation model saved!') + else: + not_improved_count += 1 + + if self.args['early_stop'] and not_improved_count == self.args['early_stop_patience']: + self.logger.info( + f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. 
Training stops.") + break + + if test_epoch_loss < best_test_loss: + best_test_loss = test_epoch_loss + best_test_model = copy.deepcopy(self.model.state_dict()) + torch.save(best_test_model, self.best_test_path) + torch.save(best_model, self.pretrain_best_path) + + if not self.args['debug']: + torch.save(best_model, self.best_path) + torch.save(best_test_model, self.best_test_path) + self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}") + + self._finalize_training(best_model, best_test_model) + + def _finalize_training(self, best_model, best_test_model): + self.model.load_state_dict(best_model) + self.logger.info("Testing on best validation model") + self.test(self.model, self.args, self.test_loader, self.scaler, self.logger) + + self.model.load_state_dict(best_test_model) + self.logger.info("Testing on best test model") + self.test(self.model, self.args, self.test_loader, self.scaler, self.logger) + + def loadTeacher(self, args): + model_path = f'./pre-train/{args["model"]["type"]}/{args["data"]["type"]}/best_model.pth' + try: + # 尝试加载教师模型权重 + state_dict = torch.load(model_path) + self.logger.info(f"成功加载教师模型权重: {model_path}") + + # 初始化并返回教师模型 + args['model']['model_type'] = 'teacher' + tmodel = STMLP(args['model']) + tmodel = tmodel.to(args['device']) + tmodel.load_state_dict(state_dict, strict=False) + return tmodel + + except FileNotFoundError: + # 如果找不到权重文件,记录日志并修改 args + self.logger.error( + f"未找到教师模型权重文件: {model_path}。切换到预训练模式训练老师权重。\n" + f"在预训练完成后,再次启动模型则为蒸馏模式") + self.args['teacher_stu'] = False + return None + + + def loss_cls(self, x1, x2): + temperature = 0.05 + x1 = F.normalize(x1, p=2, dim=-1) + x2 = F.normalize(x2, p=2, dim=-1) + weight = F.cosine_similarity(x1, x2, dim=-1) + batch_size = x1.size()[0] + # neg score + out = torch.cat([x1, x2], dim=0) + neg = torch.exp(torch.matmul(out, out.transpose(2, 3).contiguous()) / temperature) + + pos = torch.exp(torch.sum(x1 * x2, dim=-1) * weight / temperature) + # pos = torch.exp(torch.sum(x1 * x2, dim=-1) / temperature) + pos = torch.cat([pos, pos], dim=0).sum(dim=1) + + Ng = neg.sum(dim=-1).sum(dim=1) + + loss = (- torch.log(pos / (pos + Ng))).mean() + + return loss + + @staticmethod + def test(model, args, data_loader, scaler, logger, path=None): + if path: + checkpoint = torch.load(path) + model.load_state_dict(checkpoint['state_dict']) + model.to(args['device']) + + model.eval() + y_pred, y_true = [], [] + + with torch.no_grad(): + for data, target in data_loader: + label = target[..., :args['output_dim']] + output, _, _ = model(data) + y_pred.append(output) + y_true.append(label) + + if args['real_value']: + y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) + else: + y_pred = torch.cat(y_pred, dim=0) + y_true = torch.cat(y_true, dim=0) + + # 你在这里需要把y_pred和y_true保存下来 + # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt") # [3566,12,170,1] + # torch.save(y_true, "./test/PEMSD8/y_true.pt") # [3566,12,170,1] + + for t in range(y_true.shape[1]): + mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], + args['mae_thresh'], args['mape_thresh']) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + + mae, rmse, mape = all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh']) + logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + + @staticmethod + def _compute_sampling_threshold(global_step, k): + return k / (k + math.exp(global_step / k)) + + diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 
5613870..45539c1 100644 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -107,6 +107,7 @@ class Trainer: best_loss = val_epoch_loss not_improved_count = 0 best_model = copy.deepcopy(self.model.state_dict()) + torch.save(best_model, self.best_path) self.logger.info('Best validation model saved!') else: not_improved_count += 1 @@ -118,6 +119,7 @@ class Trainer: if test_epoch_loss < best_test_loss: best_test_loss = test_epoch_loss + torch.save(best_test_model, self.best_test_path) best_test_model = copy.deepcopy(self.model.state_dict()) if not self.args['debug']: @@ -161,7 +163,7 @@ class Trainer: # 你在这里需要把y_pred和y_true保存下来 # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt") # [3566,12,170,1] - # torch.save(y_true, "./test/PEMS08/y_true.pt") # [3566,12,170,1] + # torch.save(y_true, "./test/PEMSD8/y_true.pt") # [3566,12,170,1] for t in range(y_true.shape[1]): mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], diff --git a/trainer/trainer_selector.py b/trainer/trainer_selector.py index eaad3ab..c52ca67 100644 --- a/trainer/trainer_selector.py +++ b/trainer/trainer_selector.py @@ -2,6 +2,7 @@ from trainer.Trainer import Trainer from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer +from trainer.STMLP_Trainer import Trainer as STMLP_Trainer def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, @@ -13,5 +14,7 @@ def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader lr_scheduler) case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], lr_scheduler) + case 'STMLP': return STMLP_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, + lr_scheduler) case _: return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'], lr_scheduler) diff --git a/transfer_guide.md b/transfer_guide.md index 43b4349..8d4bde1 100644 --- a/transfer_guide.md +++ b/transfer_guide.md @@ -299,7 +299,7 @@ def read_data(args): 'pems03': ['PEMS03/pems03.npz', 'PEMS03/distance.csv'], 'pems04': ['PEMS04/pems04.npz', 'PEMS04/distance.csv'], 'pems07': ['PEMS07/pems07.npz', 'PEMS07/distance.csv'], - 'pems08': ['PEMS08/pems08.npz', 'PEMS08/distance.csv'], + 'pems08': ['PEMSD8/pems08.npz', 'PEMSD8/distance.csv'], 'pemsbay': ['PEMSBAY/pems_bay.npz', 'PEMSBAY/distance.csv'], 'pemsD7M': ['PeMSD7M/PeMSD7M.npz', 'PeMSD7M/distance.csv'], 'pemsD7L': ['PeMSD7L/PeMSD7L.npz', 'PeMSD7L/distance.csv'] From 97eb39073abc75f861362ba6e0bef390f105a292 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 23 Apr 2025 23:22:50 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E6=B7=BB=E5=8A=A0STIDGCN?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/STIDGCN/PEMSD3.yaml | 48 +++++ config/STIDGCN/PEMSD4.yaml | 48 +++++ config/STIDGCN/PEMSD7.yaml | 48 +++++ config/STIDGCN/PEMSD8.yaml | 48 +++++ model/STIDGCN/STIDGCN.py | 368 +++++++++++++++++++++++++++++++++++++ model/model_selector.py | 9 +- 6 files changed, 565 insertions(+), 4 deletions(-) create mode 100644 config/STIDGCN/PEMSD3.yaml create mode 100644 config/STIDGCN/PEMSD4.yaml create mode 100644 config/STIDGCN/PEMSD7.yaml create mode 100644 config/STIDGCN/PEMSD8.yaml create mode 100644 model/STIDGCN/STIDGCN.py diff --git a/config/STIDGCN/PEMSD3.yaml b/config/STIDGCN/PEMSD3.yaml new file mode 100644 index 
0000000..c3d8ce2 --- /dev/null +++ b/config/STIDGCN/PEMSD3.yaml @@ -0,0 +1,48 @@ +data: + num_nodes: 358 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 3 + output_dim: 1 + history: 12 + horizon: 12 + granularity: 288 + dropout: 0.1 + channels: 32 + + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 200 + plot: False diff --git a/config/STIDGCN/PEMSD4.yaml b/config/STIDGCN/PEMSD4.yaml new file mode 100644 index 0000000..fb4dae6 --- /dev/null +++ b/config/STIDGCN/PEMSD4.yaml @@ -0,0 +1,48 @@ +data: + num_nodes: 307 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 3 + output_dim: 1 + history: 12 + horizon: 12 + granularity: 288 + dropout: 0.1 + channels: 64 + + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 200 + plot: False diff --git a/config/STIDGCN/PEMSD7.yaml b/config/STIDGCN/PEMSD7.yaml new file mode 100644 index 0000000..4f67ce1 --- /dev/null +++ b/config/STIDGCN/PEMSD7.yaml @@ -0,0 +1,48 @@ +data: + num_nodes: 883 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 3 + output_dim: 1 + history: 12 + horizon: 12 + granularity: 288 + dropout: 0.1 + channels: 128 + + +train: + loss_func: mae + seed: 10 + batch_size: 16 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 200 + plot: False diff --git a/config/STIDGCN/PEMSD8.yaml b/config/STIDGCN/PEMSD8.yaml new file mode 100644 index 0000000..354c9e3 --- /dev/null +++ b/config/STIDGCN/PEMSD8.yaml @@ -0,0 +1,48 @@ +data: + num_nodes: 170 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 3 + output_dim: 1 + history: 12 + horizon: 12 + granularity: 288 + dropout: 0.1 + channels: 96 + + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 200 + plot: False diff --git a/model/STIDGCN/STIDGCN.py 
b/model/STIDGCN/STIDGCN.py new file mode 100644 index 0000000..8c89fb9 --- /dev/null +++ b/model/STIDGCN/STIDGCN.py @@ -0,0 +1,368 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +class GLU(nn.Module): + def __init__(self, features, dropout=0.1): + super(GLU, self).__init__() + self.conv1 = nn.Conv2d(features, features, (1, 1)) + self.conv2 = nn.Conv2d(features, features, (1, 1)) + self.conv3 = nn.Conv2d(features, features, (1, 1)) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x) + out = x1 * torch.sigmoid(x2) + out = self.dropout(out) + out = self.conv3(out) + return out + + +# class TemporalEmbedding(nn.Module): +# def __init__(self, time, features): +# super(TemporalEmbedding, self).__init__() +# +# self.time = time +# # self.time_day = nn.Parameter(torch.empty(time, features)) +# # nn.init.xavier_uniform_(self.time_day) +# # +# # self.time_week = nn.Parameter(torch.empty(7, features)) +# # nn.init.xavier_uniform_(self.time_week) +# self.time_day = nn.Embedding(time, features) +# self.time_week = nn.Embedding(7, features) +# +# def forward(self, x): +# day_emb = x[..., 1] +# # time_day = self.time_day[(day_emb[:, :, :] * self.time).type(torch.LongTensor)] +# # time_day = time_day.transpose(1, 2).contiguous() +# +# week_emb = x[..., 2] +# # time_week = self.time_week[(week_emb[:, :, :]).type(torch.LongTensor)] +# # time_week = time_week.transpose(1, 2).contiguous() +# +# t_idx = (day_emb[:, -1, :, ] * (self.time - 1)).long() # (B, N) +# d_idx = week_emb[:, -1, :, ].long() # (B, N) +# # time_emb = self.time_embedding(t_idx) # (B, N, hidden_dim) +# # day_emb = self.day_embedding(d_idx) # (B, N, hidden_dim) +# +# tem_emb = t_idx + d_idx +# +# # tem_emb = tem_emb.permute(0, 3, 1, 2) +# +# return tem_emb +class TemporalEmbedding(nn.Module): + def __init__(self, time, features): + super(TemporalEmbedding, self).__init__() + + self.time = time + self.time_day = nn.Parameter(torch.empty(time, features)) + nn.init.xavier_uniform_(self.time_day) + + self.time_week = nn.Parameter(torch.empty(7, features)) + nn.init.xavier_uniform_(self.time_week) + + def forward(self, x): + day_emb = x[..., 1] + time_day = self.time_day[(day_emb[:, :, :] * self.time).type(torch.LongTensor)] + time_day = time_day.transpose(1, 2).contiguous() + + week_emb = x[..., 2] + time_week = self.time_week[(week_emb[:, :, :]).type(torch.LongTensor)] + time_week = time_week.transpose(1, 2).contiguous() + + tem_emb = time_day + time_week + + tem_emb = tem_emb.permute(0,3,1,2) + + return tem_emb + +class Diffusion_GCN(nn.Module): + def __init__(self, channels=128, diffusion_step=1, dropout=0.1): + super().__init__() + self.diffusion_step = diffusion_step + self.conv = nn.Conv2d(diffusion_step * channels, channels, (1, 1)) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, adj): + out = [] + for i in range(0, self.diffusion_step): + if adj.dim() == 3: + x = torch.einsum("bcnt,bnm->bcmt", x, adj).contiguous() + out.append(x) + elif adj.dim() == 2: + x = torch.einsum("bcnt,nm->bcmt", x, adj).contiguous() + out.append(x) + x = torch.cat(out, dim=1) + x = self.conv(x) + output = self.dropout(x) + return output + + +class Graph_Generator(nn.Module): + def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1): + super().__init__() + self.memory = nn.Parameter(torch.randn(channels, num_nodes)) + nn.init.xavier_uniform_(self.memory) + self.fc = nn.Linear(2, 1) + + def forward(self, x): + adj_dyn_1 = 
torch.softmax( + F.relu( + torch.einsum("bcnt, cm->bnm", x, self.memory).contiguous() + / math.sqrt(x.shape[1]) + ), + -1, + ) + adj_dyn_2 = torch.softmax( + F.relu( + torch.einsum("bcn, bcm->bnm", x.sum(-1), x.sum(-1)).contiguous() + / math.sqrt(x.shape[1]) + ), + -1, + ) + # adj_dyn = (adj_dyn_1 + adj_dyn_2 + adj)/2 + adj_f = torch.cat([(adj_dyn_1).unsqueeze(-1)] + [(adj_dyn_2).unsqueeze(-1)], dim=-1) + adj_f = torch.softmax(self.fc(adj_f).squeeze(), -1) + + topk_values, topk_indices = torch.topk(adj_f, k=int(adj_f.shape[1] * 0.8), dim=-1) + mask = torch.zeros_like(adj_f) + mask.scatter_(-1, topk_indices, 1) + adj_f = adj_f * mask + + return adj_f + + +class DGCN(nn.Module): + def __init__(self, channels=128, num_nodes=170, diffusion_step=1, dropout=0.1, emb=None): + super().__init__() + self.conv = nn.Conv2d(channels, channels, (1, 1)) + self.generator = Graph_Generator(channels, num_nodes, diffusion_step, dropout) + self.gcn = Diffusion_GCN(channels, diffusion_step, dropout) + self.emb = emb + + def forward(self, x): + skip = x + x = self.conv(x) + adj_dyn = self.generator(x) + x = self.gcn(x, adj_dyn) + x = x * self.emb + skip + return x + + +class Splitting(nn.Module): + def __init__(self): + super(Splitting, self).__init__() + + def even(self, x): + return x[:, :, :, ::2] + + def odd(self, x): + return x[:, :, :, 1::2] + + def forward(self, x): + return (self.even(x), self.odd(x)) + + +class IDGCN(nn.Module): + def __init__( + self, + device, + channels=64, + diffusion_step=1, + splitting=True, + num_nodes=170, + dropout=0.2, emb=None + ): + super(IDGCN, self).__init__() + + device = device + self.dropout = dropout + self.num_nodes = num_nodes + self.splitting = splitting + self.split = Splitting() + + Conv1 = [] + Conv2 = [] + Conv3 = [] + Conv4 = [] + pad_l = 3 + pad_r = 3 + + k1 = 5 + k2 = 3 + Conv1 += [ + nn.ReplicationPad2d((pad_l, pad_r, 0, 0)), + nn.Conv2d(channels, channels, kernel_size=(1, k1)), + nn.LeakyReLU(negative_slope=0.01, inplace=True), + nn.Dropout(self.dropout), + nn.Conv2d(channels, channels, kernel_size=(1, k2)), + nn.Tanh(), + ] + Conv2 += [ + nn.ReplicationPad2d((pad_l, pad_r, 0, 0)), + nn.Conv2d(channels, channels, kernel_size=(1, k1)), + nn.LeakyReLU(negative_slope=0.01, inplace=True), + nn.Dropout(self.dropout), + nn.Conv2d(channels, channels, kernel_size=(1, k2)), + nn.Tanh(), + ] + Conv4 += [ + nn.ReplicationPad2d((pad_l, pad_r, 0, 0)), + nn.Conv2d(channels, channels, kernel_size=(1, k1)), + nn.LeakyReLU(negative_slope=0.01, inplace=True), + nn.Dropout(self.dropout), + nn.Conv2d(channels, channels, kernel_size=(1, k2)), + nn.Tanh(), + ] + Conv3 += [ + nn.ReplicationPad2d((pad_l, pad_r, 0, 0)), + nn.Conv2d(channels, channels, kernel_size=(1, k1)), + nn.LeakyReLU(negative_slope=0.01, inplace=True), + nn.Dropout(self.dropout), + nn.Conv2d(channels, channels, kernel_size=(1, k2)), + nn.Tanh(), + ] + + self.conv1 = nn.Sequential(*Conv1) + self.conv2 = nn.Sequential(*Conv2) + self.conv3 = nn.Sequential(*Conv3) + self.conv4 = nn.Sequential(*Conv4) + + self.dgcn = DGCN(channels, num_nodes, diffusion_step, dropout, emb) + + def forward(self, x): + if self.splitting: + (x_even, x_odd) = self.split(x) + else: + (x_even, x_odd) = x + + x1 = self.conv1(x_even) + x1 = self.dgcn(x1) + d = x_odd.mul(torch.tanh(x1)) + + x2 = self.conv2(x_odd) + x2 = self.dgcn(x2) + c = x_even.mul(torch.tanh(x2)) + + x3 = self.conv3(c) + x3 = self.dgcn(x3) + x_odd_update = d + x3 + + x4 = self.conv4(d) + x4 = self.dgcn(x4) + x_even_update = c + x4 + + return (x_even_update, x_odd_update) 
+ + +class IDGCN_Tree(nn.Module): + def __init__( + self, device, channels=64, diffusion_step=1, num_nodes=170, dropout=0.1 + ): + super().__init__() + + self.memory1 = nn.Parameter(torch.randn(channels, num_nodes, 6)) + self.memory2 = nn.Parameter(torch.randn(channels, num_nodes, 3)) + self.memory3 = nn.Parameter(torch.randn(channels, num_nodes, 3)) + + self.IDGCN1 = IDGCN( + device=device, + splitting=True, + channels=channels, + diffusion_step=diffusion_step, + num_nodes=num_nodes, + dropout=dropout, emb=self.memory1 + ) + self.IDGCN2 = IDGCN( + device=device, + splitting=True, + channels=channels, + diffusion_step=diffusion_step, + num_nodes=num_nodes, + dropout=dropout, emb=self.memory2 + ) + self.IDGCN3 = IDGCN( + device=device, + splitting=True, + channels=channels, + diffusion_step=diffusion_step, + num_nodes=num_nodes, + dropout=dropout, emb=self.memory2 + ) + + def concat(self, even, odd): + even = even.permute(3, 1, 2, 0) + odd = odd.permute(3, 1, 2, 0) + len = even.shape[0] + _ = [] + for i in range(len): + _.append(even[i].unsqueeze(0)) + _.append(odd[i].unsqueeze(0)) + return torch.cat(_, 0).permute(3, 1, 2, 0) + + def forward(self, x): + x_even_update1, x_odd_update1 = self.IDGCN1(x) + x_even_update2, x_odd_update2 = self.IDGCN2(x_even_update1) + x_even_update3, x_odd_update3 = self.IDGCN3(x_odd_update1) + concat1 = self.concat(x_even_update2, x_odd_update2) + concat2 = self.concat(x_even_update3, x_odd_update3) + concat0 = self.concat(concat1, concat2) + output = concat0 + x + return output + + +class STIDGCN(nn.Module): + def __init__(self, args): + """ + device, input_dim, num_nodes, channels, granularity, dropout=0.1 + """ + super().__init__() + + device = args['device'] + input_dim = args['input_dim'] + self.num_nodes = args['num_nodes'] + self.output_len = 12 + channels = args['channels'] + granularity = args['granularity'] + dropout = args['dropout'] + diffusion_step = 1 + + self.Temb = TemporalEmbedding(granularity, channels) + + self.start_conv = nn.Conv2d( + in_channels=input_dim, out_channels=channels, kernel_size=(1, 1) + ) + + self.tree = IDGCN_Tree( + device=device, + channels=channels * 2, + diffusion_step=diffusion_step, + num_nodes=self.num_nodes, + dropout=dropout, + ) + + self.glu = GLU(channels * 2, dropout) + + self.regression_layer = nn.Conv2d( + channels * 2, self.output_len, kernel_size=(1, self.output_len) + ) + + def param_num(self): + return sum([param.nelement() for param in self.parameters()]) + + def forward(self, input): + input = input.transpose(1, 3) + x = input + # Encoder + # Data Embedding + time_emb = self.Temb(input.permute(0, 3, 2, 1)) + x = torch.cat([self.start_conv(x)] + [time_emb], dim=1) + # IDGCN_Tree + x = self.tree(x) + # Decoder + gcn = self.glu(x) + x + prediction = self.regression_layer(F.relu(gcn)) + return prediction \ No newline at end of file diff --git a/model/model_selector.py b/model/model_selector.py index 54b8c5f..9067833 100644 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -13,8 +13,8 @@ from model.STFGNN.STFGNN import STFGNN from model.STSGCN.STSGCN import STSGCN from model.STGODE.STGODE import ODEGCN from model.PDG2SEQ.PDG2Seq import PDG2Seq -from model.EXP.EXP import EXP -from model.EXPB.EXP_b import EXPB +from model.STIDGCN.STIDGCN import STIDGCN + def model_selector(model): match model['type']: @@ -33,6 +33,7 @@ def model_selector(model): case 'STSGCN': return STSGCN(model) case 'STGODE': return ODEGCN(model) case 'PDG2SEQ': return PDG2Seq(model) - case 'EXP': return EXP(model) - case 'EXPB': 
return EXPB(model) + case 'STIDGCN': return STIDGCN(model) + # case 'EXP': return EXP(model) + # case 'EXPB': return EXPB(model) From 5f8c31af2e5a58999b949af1f2ba6394afad3da2 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 23 Apr 2025 23:24:43 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E6=B7=BB=E5=8A=A0STIDGCN?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/model_selector.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/model/model_selector.py b/model/model_selector.py index 0117619..091a465 100644 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -14,6 +14,7 @@ from model.STSGCN.STSGCN import STSGCN from model.STGODE.STGODE import ODEGCN from model.PDG2SEQ.PDG2Seq import PDG2Seq from model.STMLP.STMLP import STMLP +from model.STIDGCN.STIDGCN import STIDGCN def model_selector(model): match model['type']: @@ -33,4 +34,5 @@ def model_selector(model): case 'STGODE': return ODEGCN(model) case 'PDG2SEQ': return PDG2Seq(model) case 'STMLP': return STMLP(model) + case 'STIDGCN': return STIDGCN(model) From c15cf605be7b319d12f24731c40b14c2099ec00c Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sat, 26 Apr 2025 15:37:40 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E6=B7=BB=E5=8A=A0md=E7=BB=93=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Result.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 Result.md diff --git a/Result.md b/Result.md new file mode 100644 index 0000000..935fb88 --- /dev/null +++ b/Result.md @@ -0,0 +1,19 @@ +| NO. | Baselines | PEMS03 MAE | PEMS03 RMSE | PEMS03 MAPE | PEMS04 MAE | PEMS04 RMSE | PEMS04 MAPE | PEMS07 MAE | PEMS07 RMSE | PEMS07 MAPE | PEMS08 MAE | PEMS08 RMSE | PEMS08 MAPE | 备注 | +|-----|----------------|------------|-------------|-------------|------------|-------------|-------------|------------|-------------|-------------|------------|-------------|-------------|--------| +| 1 | HA | | | | | | | | | | | | | 未实现 | +| 2 | ARIMA | 30.99 | 48.28 | 28.66% | 39.7 | 59.12 | 27.57% | / | / | / | 32.51 | 48.5 | 19.94% | 偏高 | +| 3 | VAR | | | | | | | | | | | | | 未实现 | +| 4 | FC-LSTM | | | | | | | | | | | | | 未实现 | +| 5 | TCN | 29.51 | 45.79 | 29.11% | 37.6 | 55.5 | 26.81% | 42.6 | 62.19 | 20.22% | 31.18 | 45.8 | 20.64% | 偏高 | +| 6 | GRU-ED | | | | | | | | | | | | | 未实现 | +| 7 | DSANET | 21.26 | 34.44 | 21.18% | 27.77 | 43.89 | 18.88% | 31.59 | 49.42 | 13.93% | 22.38 | 35.48 | 14.26% | 合理 | +| 8 | STGCN | 17.41 | 29.31 | 18.91% | 20.58 | 32.7 | 14.75% | 23.17 | 36.73 | 10.54% | 18.05 | 27.69 | 13.67% | 合理 | +| 9 | DCRNN | 39.62 | 64.18 | 64.05% | 44.14 | 64.21 | 44.59% | 52.78 | 82.99 | 43.32% | 45.27 | 69.25 | 52.85% | 偏高 | +| 10 | GraphWaveNet | 14.68 | 25.86 | 14.38% | 19.19 | 31.04 | 13.06% | 20.40 | 33.48 | 8.73% | 14.83 | 23.86 | 10.14% | 偏低 | +| 11 | STSGCN | 18.41 | 30.77 | 19.28% | 21.4 | 35.04 | 14.28% | 24.47 | 38.96 | 10.77% | 17.58 | 27.19 | 12.00% | 合理 | +| 12 | AGCRN | 15.21 | 26.52 | 14.71% | 19.28 | 31.35 | 12.98% | 20.46 | 33.79 | 8.70% | 15.76 | 25.23 | 10.25% | 合理 | +| 13 | STFGNN | 17.29 | 29.56 | 17.48% | 23.06 | 36.23 | 15.52% | 24.67 | 38.93 | 10.89% | 16.87 | 27.48 | 11.16% | 合理 | +| 14 | STGODE | 16.55 | 26.62 | 17.58% | 22.55 | 35.05 | 15.91% | 23.28 | 26.19 | 10.97% | 17.22 | 26.66 | 11.52% | 合理 | +| 15 | STG-NCDE | 16.09 | 26.78 | 16.58% | 19.82 | 31.71 | 13.21% | 22.54 | 35.44 | 9.85% | 15.85 | 25.05 | 10.19% | 合理 | +| 16 | DDGCRN | 14.51 | 24.83 | 14.51% | 18.34 | 30.77 | 12.17% | 19.68 | 
33.40 | 8.23% | 14.39 | 23.75 | 9.42% | 偏低 | +| 17 | TWDGCN | 14.65 | 24.84 | 14.66% | 18.54 | 30.53 | 12.29% | 20.01 | 33.62 | 8.50% | 14.65 | 24.19 | 9.51% | 偏高 |
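
For reference, below is a minimal standalone sketch of how `trainer/STMLP_Trainer.py` combines its loss terms when distillation is enabled (`train.teacher_stu: True`). The helper name, default arguments, and toy tensor shapes are illustrative assumptions, not part of the patch; the 10/1 weighting mirrors `loss = loss1 + 10 * tkloss + 1 * scl` in `_run_epoch`, where `loss1` is the configured MAE loss, `tkloss` is a KL term between the log-softmaxed student and teacher predictions, and `scl` is the contrastive feature term from `loss_cls`.

```python
import torch
import torch.nn.functional as F
from torch import nn


def distillation_loss(student_out, teacher_out, target, hard_loss_fn, feat_loss,
                      kl_weight=10.0, feat_weight=1.0):
    # 1) Hard loss between the student prediction and the ground-truth labels
    #    (MAE in the shipped STMLP configs).
    hard = hard_loss_fn(student_out, target)

    # 2) Soft loss: KL divergence between the log-softmax of the student and
    #    teacher outputs; log_target=True means both arguments are
    #    log-probabilities, matching the trainer's usage of nn.KLDivLoss.
    kl = nn.KLDivLoss(reduction="batchmean", log_target=True)
    soft = kl(F.log_softmax(student_out, dim=-1),
              F.log_softmax(teacher_out, dim=-1))

    # 3) Weighted sum, mirroring `loss1 + 10 * tkloss + 1 * scl`.
    return hard + kl_weight * soft + feat_weight * feat_loss


# Toy usage with assumed (batch, horizon, nodes) shapes.
student = torch.randn(8, 12, 170)
teacher = torch.randn(8, 12, 170)
labels = torch.randn(8, 12, 170)
loss = distillation_loss(student, teacher, labels, nn.L1Loss(), torch.tensor(0.0))
```

When no teacher checkpoint is found at `./pre-train/<model>/<dataset>/best_model.pth`, `loadTeacher` logs the error, sets `teacher_stu` to False, and the trainer runs in pre-training mode; after pre-training, the saved weights act as the teacher, and a second run with `teacher_stu: True` performs distillation.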