diff --git a/.gitignore b/.gitignore index a68924d..d67c4d1 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ experiments/ *.pkl data/ pretrain/ +pre-train/ # ---> Python # Byte-compiled / optimized / DLL files diff --git a/config/STMLP/PEMSD3.yaml b/config/STMLP/PEMSD3.yaml new file mode 100644 index 0000000..eee7a15 --- /dev/null +++ b/config/STMLP/PEMSD3.yaml @@ -0,0 +1,66 @@ +data: + num_nodes: 358 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + input_window: 12 + output_window: 12 + gcn_true: true + buildA_true: true + gcn_depth: 2 + dropout: 0.3 + subgraph_size: 20 + node_dim: 40 + dilation_exponential: 1 + conv_channels: 32 + residual_channels: 32 + skip_channels: 64 + end_channels: 128 + layers: 3 + propalpha: 0.05 + tanhalpha: 3 + layer_norm_affline: true + use_curriculum_learning: true + step_size1: 2500 + task_level: 0 + num_split: 1 + step_size2: 100 + model_type: stmlp + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + teacher_stu: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 2000 + plot: False diff --git a/config/STMLP/PEMSD4.yaml b/config/STMLP/PEMSD4.yaml new file mode 100644 index 0000000..c416fc4 --- /dev/null +++ b/config/STMLP/PEMSD4.yaml @@ -0,0 +1,67 @@ +data: + num_nodes: 307 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + input_window: 12 + output_window: 12 + gcn_true: true + buildA_true: true + gcn_depth: 2 + dropout: 0.3 + subgraph_size: 20 + node_dim: 40 + dilation_exponential: 1 + conv_channels: 32 + residual_channels: 32 + skip_channels: 64 + end_channels: 128 + layers: 3 + propalpha: 0.05 + tanhalpha: 3 + layer_norm_affline: true + use_curriculum_learning: true + step_size1: 2500 + task_level: 0 + num_split: 1 + step_size2: 100 + model_type: stmlp + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + teacher: True + teacher_stu: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 2000 + plot: False diff --git a/config/STMLP/PEMSD7.yaml b/config/STMLP/PEMSD7.yaml new file mode 100644 index 0000000..14e6382 --- /dev/null +++ b/config/STMLP/PEMSD7.yaml @@ -0,0 +1,66 @@ +data: + num_nodes: 883 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + input_window: 12 + output_window: 12 + gcn_true: true + buildA_true: true + gcn_depth: 2 + dropout: 0.3 + subgraph_size: 20 + node_dim: 40 + dilation_exponential: 1 + conv_channels: 32 + residual_channels: 32 + skip_channels: 64 + end_channels: 128 + layers: 3 + propalpha: 0.05 + tanhalpha: 3 + 
layer_norm_affline: true + use_curriculum_learning: true + step_size1: 2500 + task_level: 0 + num_split: 1 + step_size2: 100 + model_type: stmlp + +train: + loss_func: mae + seed: 10 + batch_size: 16 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + teacher_stu: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 2000 + plot: False diff --git a/config/STMLP/PEMSD8.yaml b/config/STMLP/PEMSD8.yaml new file mode 100644 index 0000000..bceffa5 --- /dev/null +++ b/config/STMLP/PEMSD8.yaml @@ -0,0 +1,66 @@ +data: + num_nodes: 170 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + input_window: 12 + output_window: 12 + gcn_true: true + buildA_true: true + gcn_depth: 2 + dropout: 0.3 + subgraph_size: 20 + node_dim: 40 + dilation_exponential: 1 + conv_channels: 32 + residual_channels: 32 + skip_channels: 64 + end_channels: 128 + layers: 3 + propalpha: 0.05 + tanhalpha: 3 + layer_norm_affline: true + use_curriculum_learning: true + step_size1: 2500 + task_level: 0 + num_split: 1 + step_size2: 100 + model_type: stmlp + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + teacher_stu: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 2000 + plot: False diff --git a/lib/Download_data.py b/lib/Download_data.py index 9cc0006..ed7c929 100644 --- a/lib/Download_data.py +++ b/lib/Download_data.py @@ -121,7 +121,7 @@ def download_kaggle_data(current_dir): 如果目标文件夹已存在,会覆盖冲突的文件。 """ try: - print("正在下载 KaggleHub 数据集...") + print("正在下载 PEMS 数据集...") path = kagglehub.dataset_download("elmahy/pems-dataset") # print("Path to KaggleHub dataset files:", path) diff --git a/model/STMLP/STMLP.py b/model/STMLP/STMLP.py new file mode 100644 index 0000000..8af1134 --- /dev/null +++ b/model/STMLP/STMLP.py @@ -0,0 +1,307 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +from data.get_adj import get_adj +import numbers + + +# --- 基础算子 --- +class NConv(nn.Module): + def forward(self, x, adj): + return torch.einsum('ncwl,vw->ncvl', (x, adj)).contiguous() + + +class DyNconv(nn.Module): + def forward(self, x, adj): + return torch.einsum('ncvl,nvwl->ncwl', (x, adj)).contiguous() + + +class Linear(nn.Module): + def __init__(self, c_in, c_out, bias=True): + super().__init__() + self.mlp = nn.Conv2d(c_in, c_out, kernel_size=1, bias=bias) + + def forward(self, x): + return self.mlp(x) + + +class Prop(nn.Module): + def __init__(self, c_in, c_out, gdep, dropout, alpha): + super().__init__() + self.nconv = NConv() + self.mlp = Linear(c_in, c_out) + self.gdep, self.dropout, self.alpha = gdep, dropout, alpha + + def forward(self, x, adj): + adj = adj + torch.eye(adj.size(0), device=x.device) + d = adj.sum(1) + a = adj / d.view(-1, 1) + h = x + for _ in range(self.gdep): + h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a) + return self.mlp(h) + + +class MixProp(nn.Module): + def __init__(self, c_in, c_out, gdep, dropout, alpha): + super().__init__() 
+ self.nconv = NConv() + self.mlp = Linear((gdep + 1) * c_in, c_out) + self.gdep, self.dropout, self.alpha = gdep, dropout, alpha + + def forward(self, x, adj): + adj = adj + torch.eye(adj.size(0), device=x.device) + d = adj.sum(1) + a = adj / d.view(-1, 1) + out = [x] + h = x + for _ in range(self.gdep): + h = self.alpha * x + (1 - self.alpha) * self.nconv(h, a) + out.append(h) + return self.mlp(torch.cat(out, dim=1)) + + +class DyMixprop(nn.Module): + def __init__(self, c_in, c_out, gdep, dropout, alpha): + super().__init__() + self.nconv = DyNconv() + self.mlp1 = Linear((gdep + 1) * c_in, c_out) + self.mlp2 = Linear((gdep + 1) * c_in, c_out) + self.gdep, self.dropout, self.alpha = gdep, dropout, alpha + self.lin1, self.lin2 = Linear(c_in, c_in), Linear(c_in, c_in) + + def forward(self, x): + x1 = torch.tanh(self.lin1(x)) + x2 = torch.tanh(self.lin2(x)) + adj = self.nconv(x1.transpose(2, 1), x2) + adj0 = torch.softmax(adj, dim=2) + adj1 = torch.softmax(adj.transpose(2, 1), dim=2) + # 两条分支 + out1, out2 = [x], [x] + h = x + for _ in range(self.gdep): + h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj0) + out1.append(h) + h = x + for _ in range(self.gdep): + h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj1) + out2.append(h) + return self.mlp1(torch.cat(out1, dim=1)) + self.mlp2(torch.cat(out2, dim=1)) + + +class DilatedInception(nn.Module): + def __init__(self, cin, cout, dilation_factor=2): + super().__init__() + self.kernels = [2, 3, 6, 7] + cout_each = int(cout / len(self.kernels)) + self.convs = nn.ModuleList([nn.Conv2d(cin, cout_each, kernel_size=(1, k), dilation=(1, dilation_factor)) + for k in self.kernels]) + + def forward(self, x): + outs = [conv(x)[..., -self.convs[-1](x).size(3):] for conv in self.convs] + return torch.cat(outs, dim=1) + + +class GraphConstructor(nn.Module): + def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): + super().__init__() + self.nnodes, self.k, self.dim, self.alpha, self.device = nnodes, k, dim, alpha, device + self.static_feat = static_feat + if static_feat is not None: + xd = static_feat.shape[1] + self.lin1, self.lin2 = nn.Linear(xd, dim), nn.Linear(xd, dim) + else: + self.emb1 = nn.Embedding(nnodes, dim) + self.emb2 = nn.Embedding(nnodes, dim) + self.lin1, self.lin2 = nn.Linear(dim, dim), nn.Linear(dim, dim) + + def forward(self, idx): + if self.static_feat is None: + vec1, vec2 = self.emb1(idx), self.emb2(idx) + else: + vec1 = vec2 = self.static_feat[idx, :] + vec1 = torch.tanh(self.alpha * self.lin1(vec1)) + vec2 = torch.tanh(self.alpha * self.lin2(vec2)) + a = torch.mm(vec1, vec2.transpose(1, 0)) - torch.mm(vec2, vec1.transpose(1, 0)) + adj = F.relu(torch.tanh(self.alpha * a)) + mask = torch.zeros(idx.size(0), idx.size(0), device=self.device) + s1, t1 = adj.topk(self.k, 1) + mask.scatter_(1, t1, s1.new_ones(s1.size())) + return adj * mask + + +class LayerNorm(nn.Module): + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] + + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super().__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape, self.eps, self.elementwise_affine = tuple(normalized_shape), eps, elementwise_affine + if elementwise_affine: + self.weight = nn.Parameter(torch.Tensor(*normalized_shape)) + self.bias = nn.Parameter(torch.Tensor(*normalized_shape)) + init.ones_(self.weight); + init.zeros_(self.bias) + else: + self.register_parameter('weight', None) + 
self.register_parameter('bias', None) + + def forward(self, x, idx): + if self.elementwise_affine: + return F.layer_norm(x, tuple(x.shape[1:]), self.weight[:, idx, :], self.bias[:, idx, :], self.eps) + else: + return F.layer_norm(x, tuple(x.shape[1:]), self.weight, self.bias, self.eps) + + def extra_repr(self): + return f'{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}' + + +# --- 合并后的模型类,支持 teacher 与 stmlp 两种分支 --- +class STMLP(nn.Module): + def __init__(self, args): + super().__init__() + # 参数从字典中读取 + self.adj_mx = get_adj(args) + self.num_nodes = args['num_nodes'] + self.feature_dim = args['input_dim'] + + self.input_window = args['input_window'] + self.output_window = args['output_window'] + self.output_dim = args['output_dim'] + self.device = args['device'] + + self.gcn_true = args['gcn_true'] + self.buildA_true = args['buildA_true'] + self.gcn_depth = args['gcn_depth'] + self.dropout = args['dropout'] + self.subgraph_size = args['subgraph_size'] + self.node_dim = args['node_dim'] + self.dilation_exponential = args['dilation_exponential'] + + self.conv_channels = args['conv_channels'] + self.residual_channels = args['residual_channels'] + self.skip_channels = args['skip_channels'] + self.end_channels = args['end_channels'] + + self.layers = args['layers'] + self.propalpha = args['propalpha'] + self.tanhalpha = args['tanhalpha'] + self.layer_norm_affline = args['layer_norm_affline'] + + self.model_type = args['model_type'] # 'teacher' 或 'stmlp' + self.idx = torch.arange(self.num_nodes).to(self.device) + self.predefined_A = None if self.adj_mx is None else (torch.tensor(self.adj_mx) - torch.eye(self.num_nodes)).to( + self.device) + self.static_feat = None + + # transformer(保留原有结构) + self.encoder_layer = nn.TransformerEncoderLayer(d_model=12, nhead=4, batch_first=True) + self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=3) + + # 构建各层 + self.start_conv = nn.Conv2d(self.feature_dim, self.residual_channels, kernel_size=1) + self.gc = GraphConstructor(self.num_nodes, self.subgraph_size, self.node_dim, self.device, alpha=self.tanhalpha, + static_feat=self.static_feat) + # 计算 receptive_field + kernel_size = 7 + if self.dilation_exponential > 1: + self.receptive_field = int( + self.output_dim + (kernel_size - 1) * (self.dilation_exponential ** self.layers - 1) / ( + self.dilation_exponential - 1)) + else: + self.receptive_field = self.layers * (kernel_size - 1) + self.output_dim + + self.filter_convs = nn.ModuleList() + self.gate_convs = nn.ModuleList() + self.residual_convs = nn.ModuleList() + self.skip_convs = nn.ModuleList() + self.norm = nn.ModuleList() + self.stu_mlp = nn.ModuleList([nn.Sequential(nn.Linear(c, c), nn.Linear(c, c), nn.Linear(c, c)) + for c in [13, 7, 1]]) + if self.gcn_true: + self.gconv1 = nn.ModuleList() + self.gconv2 = nn.ModuleList() + + new_dilation = 1 + for i in range(1): + rf_size_i = int(1 + i * (kernel_size - 1) * (self.dilation_exponential ** self.layers - 1) / ( + self.dilation_exponential - 1)) if self.dilation_exponential > 1 else i * self.layers * ( + kernel_size - 1) + 1 + for j in range(1, self.layers + 1): + rf_size_j = int(rf_size_i + (kernel_size - 1) * (self.dilation_exponential ** j - 1) / ( + self.dilation_exponential - 1)) if self.dilation_exponential > 1 else rf_size_i + j * ( + kernel_size - 1) + self.filter_convs.append( + DilatedInception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation)) + self.gate_convs.append( + DilatedInception(self.residual_channels, 
self.conv_channels, dilation_factor=new_dilation)) + self.residual_convs.append(nn.Conv2d(self.conv_channels, self.residual_channels, kernel_size=1)) + k_size = (1, self.input_window - rf_size_j + 1) if self.input_window > self.receptive_field else ( + 1, self.receptive_field - rf_size_j + 1) + self.skip_convs.append(nn.Conv2d(self.conv_channels, self.skip_channels, kernel_size=k_size)) + if self.gcn_true: + self.gconv1.append(MixProp(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout, + self.propalpha)) + self.gconv2.append(MixProp(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout, + self.propalpha)) + norm_size = (self.residual_channels, self.num_nodes, + self.input_window - rf_size_j + 1) if self.input_window > self.receptive_field else ( + self.residual_channels, self.num_nodes, self.receptive_field - rf_size_j + 1) + self.norm.append(LayerNorm(norm_size, elementwise_affine=self.layer_norm_affline)) + new_dilation *= self.dilation_exponential + + self.end_conv_1 = nn.Conv2d(self.skip_channels, self.end_channels, kernel_size=1, bias=True) + self.end_conv_2 = nn.Conv2d(self.end_channels, self.output_window, kernel_size=1, bias=True) + k0 = (1, self.input_window) if self.input_window > self.receptive_field else (1, self.receptive_field) + self.skip0 = nn.Conv2d(self.feature_dim, self.skip_channels, kernel_size=k0, bias=True) + kE = (1, self.input_window - self.receptive_field + 1) if self.input_window > self.receptive_field else (1, 1) + self.skipE = nn.Conv2d(self.residual_channels, self.skip_channels, kernel_size=kE, bias=True) + # 最后输出分支,根据模型类型选择不同的头 + if self.model_type == 'teacher': + self.tt_linear1 = nn.Linear(self.residual_channels, self.input_window) + self.tt_linear2 = nn.Linear(1, 32) + self.ss_linear1 = nn.Linear(self.residual_channels, self.input_window) + self.ss_linear2 = nn.Linear(1, 32) + else: # stmlp + self.out_linear1 = nn.Linear(self.residual_channels, self.input_window) + self.out_linear2 = nn.Linear(1, 32) + + def forward(self, source, idx=None): + source = source[..., 0:1] + sout, tout = [], [] + inputs = source.transpose(1, 3) + assert inputs.size(3) == self.input_window, 'input sequence length mismatch' + if self.input_window < self.receptive_field: + inputs = F.pad(inputs, (self.receptive_field - self.input_window, 0, 0, 0)) + if self.gcn_true: + adp = self.gc(self.idx if idx is None else idx) if self.buildA_true else self.predefined_A + x = self.start_conv(inputs) + skip = self.skip0(F.dropout(inputs, self.dropout, training=self.training)) + for i in range(self.layers): + residual = x + filters = torch.tanh(self.filter_convs[i](x)) + gate = torch.sigmoid(self.gate_convs[i](x)) + x = F.dropout(filters * gate, self.dropout, training=self.training) + tout.append(x) + s = self.skip_convs[i](x) + skip = s + skip + if self.gcn_true: + x = self.gconv1[i](x, adp) + self.gconv2[i](x, adp.transpose(1, 0)) + else: + x = self.stu_mlp[i](x) + x = x + residual[:, :, :, -x.size(3):] + x = self.norm[i](x, self.idx if idx is None else idx) + sout.append(x) + skip = self.skipE(x) + skip + x = F.relu(skip) + x = F.relu(self.end_conv_1(x)) + x = self.end_conv_2(x) + if self.model_type == 'teacher': + ttout = self.tt_linear2(self.tt_linear1(tout[-1].transpose(1, 3)).transpose(1, 3)) + ssout = self.ss_linear2(self.ss_linear1(sout[-1].transpose(1, 3)).transpose(1, 3)) + return x, ttout, ssout + else: + x_ = self.out_linear2(self.out_linear1(tout[-1].transpose(1, 3)).transpose(1, 3)) + return x, x_, x diff --git a/model/model_selector.py 
b/model/model_selector.py index 54b8c5f..0117619 100644 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -13,8 +13,7 @@ from model.STFGNN.STFGNN import STFGNN from model.STSGCN.STSGCN import STSGCN from model.STGODE.STGODE import ODEGCN from model.PDG2SEQ.PDG2Seq import PDG2Seq -from model.EXP.EXP import EXP -from model.EXPB.EXP_b import EXPB +from model.STMLP.STMLP import STMLP def model_selector(model): match model['type']: @@ -33,6 +32,5 @@ def model_selector(model): case 'STSGCN': return STSGCN(model) case 'STGODE': return ODEGCN(model) case 'PDG2SEQ': return PDG2Seq(model) - case 'EXP': return EXP(model) - case 'EXPB': return EXPB(model) + case 'STMLP': return STMLP(model) diff --git a/run.py b/run.py index 7499708..082fe05 100644 --- a/run.py +++ b/run.py @@ -17,9 +17,6 @@ from dataloader.loader_selector import get_dataloader from trainer.trainer_selector import select_trainer import yaml - - - def main(): args = parse_args() diff --git a/trainer/DCRNN_Trainer.py b/trainer/DCRNN_Trainer.py index ecc4eb0..97a8290 100644 --- a/trainer/DCRNN_Trainer.py +++ b/trainer/DCRNN_Trainer.py @@ -160,10 +160,6 @@ class Trainer: y_pred = torch.cat(y_pred, dim=0) y_true = torch.cat(y_true, dim=0) - # 你在这里需要把y_pred和y_true保存下来 - # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt") # [3566,12,170,1] - # torch.save(y_true, "./test/PEMS08/y_true.pt") # [3566,12,170,1] - for t in range(y_true.shape[1]): mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], args['mae_thresh'], args['mape_thresh']) diff --git a/trainer/PDG2SEQ_Trainer.py b/trainer/PDG2SEQ_Trainer.py index 00750a1..bde4801 100644 --- a/trainer/PDG2SEQ_Trainer.py +++ b/trainer/PDG2SEQ_Trainer.py @@ -161,10 +161,6 @@ class Trainer: y_pred = torch.cat(y_pred, dim=0) y_true = torch.cat(y_true, dim=0) - # 你在这里需要把y_pred和y_true保存下来 - # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt") # [3566,12,170,1] - # torch.save(y_true, "./test/PEMS08/y_true.pt") # [3566,12,170,1] - for t in range(y_true.shape[1]): mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], args['mae_thresh'], args['mape_thresh']) diff --git a/trainer/STMLP_Trainer.py b/trainer/STMLP_Trainer.py new file mode 100644 index 0000000..6489221 --- /dev/null +++ b/trainer/STMLP_Trainer.py @@ -0,0 +1,261 @@ +import math +import os +import sys +import time +import copy +import torch.nn.functional as F +import torch +from torch import nn + +from tqdm import tqdm +from lib.logger import get_logger +from lib.loss_function import all_metrics +from model.STMLP.STMLP import STMLP + + +class Trainer: + def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, + scaler, args, lr_scheduler=None): + self.model = model + self.loss = loss + self.optimizer = optimizer + self.train_loader = train_loader + self.val_loader = val_loader + self.test_loader = test_loader + self.scaler = scaler + self.args = args['train'] + self.lr_scheduler = lr_scheduler + self.train_per_epoch = len(train_loader) + self.val_per_epoch = len(val_loader) if val_loader else 0 + + # Paths for saving models and logs + self.best_path = os.path.join(self.args['log_dir'], 'best_model.pth') + self.best_test_path = os.path.join(self.args['log_dir'], 'best_test_model.pth') + self.loss_figure_path = os.path.join(self.args['log_dir'], 'loss.png') + self.pretrain_dir = f'./pre-train/{args["model"]["type"]}/{args["data"]["type"]}' + self.pretrain_path = os.path.join(self.pretrain_dir, 'best_model.pth') + self.pretrain_best_path = os.path.join(self.pretrain_dir, 
                                               'best_test_model.pth')
+
+        # Initialize logger
+        if not os.path.isdir(self.args['log_dir']) and not self.args['debug']:
+            os.makedirs(self.args['log_dir'], exist_ok=True)
+        if not os.path.isdir(self.pretrain_dir) and not self.args['debug']:
+            os.makedirs(self.pretrain_dir, exist_ok=True)
+        self.logger = get_logger(self.args['log_dir'], name=self.model.__class__.__name__, debug=self.args['debug'])
+        self.logger.info(f"Experiment log path in: {self.args['log_dir']}")
+
+        if self.args['teacher_stu']:
+            self.tmodel = self.loadTeacher(args)
+        else:
+            self.logger.info(f"Running in pre-training mode. After pre-training, move the teacher model to "
+                             f"./pre-train/{args['model']['type']}/{args['data']['type']}/best_model.pth "
+                             f"and set train.teacher_stu to True in the config to enable distillation mode")
+
+
+    def _run_epoch(self, epoch, dataloader, mode):
+        # self.tmodel.eval()
+        if mode == 'train':
+            self.model.train()
+            optimizer_step = True
+        else:
+            self.model.eval()
+            optimizer_step = False
+
+        total_loss = 0
+        epoch_time = time.time()
+
+        with torch.set_grad_enabled(optimizer_step):
+            with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
+                for batch_idx, (data, target) in enumerate(dataloader):
+                    if self.args['teacher_stu']:
+                        label = target[..., :self.args['output_dim']]
+                        output, out_, _ = self.model(data)
+                        gout, tout, sout = self.tmodel(data)
+
+                        if self.args['real_value']:
+                            output = self.scaler.inverse_transform(output)
+
+                        loss1 = self.loss(output, label)
+                        scl = self.loss_cls(out_, sout)
+                        kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True).cuda()
+                        gout = F.log_softmax(gout, dim=-1).cuda()
+                        mlp_emb_ = F.log_softmax(output, dim=-1).cuda()
+                        tkloss = kl_loss(mlp_emb_.cuda().float(), gout.cuda().float())
+                        loss = loss1 + 10 * tkloss + 1 * scl
+
+                    else:
+                        label = target[..., :self.args['output_dim']]
+                        output, out_, _ = self.model(data)
+
+                        if self.args['real_value']:
+                            output = self.scaler.inverse_transform(output)
+
+                        loss = self.loss(output, label)
+
+                    if optimizer_step and self.optimizer is not None:
+                        self.optimizer.zero_grad()
+                        loss.backward()
+
+                        if self.args['grad_norm']:
+                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
+                        self.optimizer.step()
+
+                    total_loss += loss.item()
+
+                    if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
+                        self.logger.info(
+                            f'Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}')
+
+                    # Update the tqdm progress bar
+                    pbar.update(1)
+                    pbar.set_postfix(loss=loss.item())
+
+        avg_loss = total_loss / len(dataloader)
+        self.logger.info(
+            f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
+        return avg_loss
+
+    def train_epoch(self, epoch):
+        return self._run_epoch(epoch, self.train_loader, 'train')
+
+    def val_epoch(self, epoch):
+        return self._run_epoch(epoch, self.val_loader or self.test_loader, 'val')
+
+    def test_epoch(self, epoch):
+        return self._run_epoch(epoch, self.test_loader, 'test')
+
+    def train(self):
+        best_model, best_test_model = None, None
+        best_loss, best_test_loss = float('inf'), float('inf')
+        not_improved_count = 0
+
+        self.logger.info("Training process started")
+        for epoch in range(1, self.args['epochs'] + 1):
+            train_epoch_loss = self.train_epoch(epoch)
+            val_epoch_loss = self.val_epoch(epoch)
+            test_epoch_loss = self.test_epoch(epoch)
+
+            if train_epoch_loss > 1e6:
+                self.logger.warning('Gradient explosion detected. Ending...')
+                break
+
+            if val_epoch_loss < best_loss:
+                best_loss = val_epoch_loss
+                not_improved_count = 0
+                best_model = copy.deepcopy(self.model.state_dict())
+                torch.save(best_model, self.best_path)
+                torch.save(best_model, self.pretrain_path)
+                self.logger.info('Best validation model saved!')
+            else:
+                not_improved_count += 1
+
+            if self.args['early_stop'] and not_improved_count == self.args['early_stop_patience']:
+                self.logger.info(
+                    f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
+                break
+
+            if test_epoch_loss < best_test_loss:
+                best_test_loss = test_epoch_loss
+                best_test_model = copy.deepcopy(self.model.state_dict())
+                torch.save(best_test_model, self.best_test_path)
+                torch.save(best_test_model, self.pretrain_best_path)
+
+        if not self.args['debug']:
+            torch.save(best_model, self.best_path)
+            torch.save(best_test_model, self.best_test_path)
+            self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
+
+        self._finalize_training(best_model, best_test_model)
+
+    def _finalize_training(self, best_model, best_test_model):
+        self.model.load_state_dict(best_model)
+        self.logger.info("Testing on best validation model")
+        self.test(self.model, self.args, self.test_loader, self.scaler, self.logger)
+
+        self.model.load_state_dict(best_test_model)
+        self.logger.info("Testing on best test model")
+        self.test(self.model, self.args, self.test_loader, self.scaler, self.logger)
+
+    def loadTeacher(self, args):
+        model_path = f'./pre-train/{args["model"]["type"]}/{args["data"]["type"]}/best_model.pth'
+        try:
+            # Try to load the teacher model weights
+            state_dict = torch.load(model_path)
+            self.logger.info(f"Successfully loaded teacher model weights: {model_path}")
+
+            # Initialize and return the teacher model
+            args['model']['model_type'] = 'teacher'
+            tmodel = STMLP(args['model'])
+            tmodel = tmodel.to(args['device'])
+            tmodel.load_state_dict(state_dict, strict=False)
+            return tmodel
+
+        except FileNotFoundError:
+            # If the weight file is missing, log it and fall back to pre-training mode
+            self.logger.error(
+                f"Teacher model weights not found: {model_path}. Switching to pre-training mode to train the teacher weights.\n"
+                f"Once pre-training has finished, launching the model again will run in distillation mode")
+            self.args['teacher_stu'] = False
+            return None
+
+
+    def loss_cls(self, x1, x2):
+        temperature = 0.05
+        x1 = F.normalize(x1, p=2, dim=-1)
+        x2 = F.normalize(x2, p=2, dim=-1)
+        weight = F.cosine_similarity(x1, x2, dim=-1)
+        batch_size = x1.size()[0]
+        # neg score
+        out = torch.cat([x1, x2], dim=0)
+        neg = torch.exp(torch.matmul(out, out.transpose(2, 3).contiguous()) / temperature)
+
+        pos = torch.exp(torch.sum(x1 * x2, dim=-1) * weight / temperature)
+        # pos = torch.exp(torch.sum(x1 * x2, dim=-1) / temperature)
+        pos = torch.cat([pos, pos], dim=0).sum(dim=1)
+
+        Ng = neg.sum(dim=-1).sum(dim=1)
+
+        loss = (- torch.log(pos / (pos + Ng))).mean()
+
+        return loss
+
+    @staticmethod
+    def test(model, args, data_loader, scaler, logger, path=None):
+        if path:
+            checkpoint = torch.load(path)
+            model.load_state_dict(checkpoint['state_dict'])
+            model.to(args['device'])
+
+        model.eval()
+        y_pred, y_true = [], []
+
+        with torch.no_grad():
+            for data, target in data_loader:
+                label = target[..., :args['output_dim']]
+                output, _, _ = model(data)
+                y_pred.append(output)
+                y_true.append(label)
+
+        if args['real_value']:
+            y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
+        else:
+            y_pred = torch.cat(y_pred, dim=0)
+        y_true = torch.cat(y_true, dim=0)
+
+        # Save y_pred and y_true here if needed
+        # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt")  # [3566,12,170,1]
+        # torch.save(y_true, "./test/PEMSD8/y_true.pt")  # [3566,12,170,1]
+
+        for t in range(y_true.shape[1]):
+            mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
+                                          args['mae_thresh'], args['mape_thresh'])
+            logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
+
+        mae, rmse, mape = all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
+        logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
+
+    @staticmethod
+    def _compute_sampling_threshold(global_step, k):
+        return k / (k + math.exp(global_step / k))
+
+
diff --git a/trainer/Trainer.py b/trainer/Trainer.py
index 5613870..45539c1 100644
--- a/trainer/Trainer.py
+++ b/trainer/Trainer.py
@@ -107,6 +107,7 @@ class Trainer:
                 best_loss = val_epoch_loss
                 not_improved_count = 0
                 best_model = copy.deepcopy(self.model.state_dict())
+                torch.save(best_model, self.best_path)
                 self.logger.info('Best validation model saved!')
             else:
                 not_improved_count += 1
@@ -118,6 +119,7 @@ class Trainer:
 
             if test_epoch_loss < best_test_loss:
                 best_test_loss = test_epoch_loss
                 best_test_model = copy.deepcopy(self.model.state_dict())
+                torch.save(best_test_model, self.best_test_path)
 
             if not self.args['debug']:
@@ -161,7 +163,7 @@ class Trainer:
 
         # 你在这里需要把y_pred和y_true保存下来
         # torch.save(y_pred, "./test/PEMS07/y_pred_D.pt")  # [3566,12,170,1]
-        # torch.save(y_true, "./test/PEMS08/y_true.pt")  # [3566,12,170,1]
+        # torch.save(y_true, "./test/PEMSD8/y_true.pt")  # [3566,12,170,1]
 
         for t in range(y_true.shape[1]):
             mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
diff --git a/trainer/trainer_selector.py b/trainer/trainer_selector.py
index eaad3ab..c52ca67 100644
--- a/trainer/trainer_selector.py
+++ b/trainer/trainer_selector.py
@@ -2,6 +2,7 @@ from trainer.Trainer import Trainer
 from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
 from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer
 from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer
+from trainer.STMLP_Trainer import Trainer as STMLP_Trainer
 
 
 def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
@@ -13,5 +14,7 @@ def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader
                                           lr_scheduler)
         case 'PDG2SEQ': return PDG2SEQ_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
                                                args['train'], lr_scheduler)
+        case 'STMLP': return STMLP_Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
+                                           lr_scheduler)
         case _: return Trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args['train'],
                                lr_scheduler)
diff --git a/transfer_guide.md b/transfer_guide.md
index 43b4349..8d4bde1 100644
--- a/transfer_guide.md
+++ b/transfer_guide.md
@@ -299,7 +299,7 @@ def read_data(args):
         'pems03': ['PEMS03/pems03.npz', 'PEMS03/distance.csv'],
         'pems04': ['PEMS04/pems04.npz', 'PEMS04/distance.csv'],
         'pems07': ['PEMS07/pems07.npz', 'PEMS07/distance.csv'],
-        'pems08': ['PEMS08/pems08.npz', 'PEMS08/distance.csv'],
+        'pems08': ['PEMSD8/pems08.npz', 'PEMSD8/distance.csv'],
         'pemsbay': ['PEMSBAY/pems_bay.npz', 'PEMSBAY/distance.csv'],
         'pemsD7M': ['PeMSD7M/PeMSD7M.npz', 'PeMSD7M/distance.csv'],
         'pemsD7L': ['PeMSD7L/PeMSD7L.npz', 'PeMSD7L/distance.csv']
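For reference, a minimal sketch of the distillation objective assembled in STMLP_Trainer._run_epoch above: the supervised forecasting loss, a KL term between log-softmaxed student and teacher forecasts, and an embedding-agreement term, combined with the 10/1 weights used in the trainer. The standalone function, the tensor shapes, and the simplified cosine stand-in for loss_cls are illustrative assumptions, not part of this diff.

import torch
import torch.nn.functional as F
from torch import nn


def distillation_objective(student_pred, student_emb, teacher_pred, teacher_emb, label,
                           kd_weight=10.0, scl_weight=1.0):
    """Illustrative sketch of the combined loss in STMLP_Trainer._run_epoch (shapes assumed)."""
    # Supervised forecasting loss (MAE, matching loss_func: mae in the configs).
    pred_loss = torch.abs(student_pred - label).mean()

    # KL divergence between log-softmaxed student and teacher forecasts.
    kl = nn.KLDivLoss(reduction="batchmean", log_target=True)
    kd_loss = kl(F.log_softmax(student_pred, dim=-1), F.log_softmax(teacher_pred, dim=-1))

    # Embedding-level agreement; a simplified cosine stand-in for the contrastive loss_cls.
    scl_loss = 1.0 - F.cosine_similarity(student_emb, teacher_emb, dim=-1).mean()

    return pred_loss + kd_weight * kd_loss + scl_weight * scl_loss


# Hypothetical shapes: batch 64, horizon 12, 170 nodes (PEMSD8), embedding width 32.
student_pred = torch.randn(64, 12, 170)
teacher_pred = torch.randn(64, 12, 170)
label = torch.randn(64, 12, 170)
student_emb = torch.randn(64, 12, 170, 32)
teacher_emb = torch.randn(64, 12, 170, 32)
print(distillation_objective(student_pred, student_emb, teacher_pred, teacher_emb, label))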