From 4ccb029d7ee8948e1ff2533d1fe122a993457e8d Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 10 Dec 2025 21:53:46 +0800 Subject: [PATCH] impl PatchTST --- config/PatchTST/AirQuality.yaml | 54 ++++ config/PatchTST/BJTaxi-Inflow.yaml | 54 ++++ config/PatchTST/BJTaxi-Outflow.yaml | 54 ++++ config/PatchTST/METR-LA.yaml | 54 ++++ config/PatchTST/NYCBike-Inflow.yaml | 54 ++++ config/PatchTST/NYCBike-Outflow.yaml | 54 ++++ config/PatchTST/PEMS-BAY.yaml | 54 ++++ config/PatchTST/SolarEnergy.yaml | 54 ++++ dataloader/loader_selector.py | 2 +- model/MTGNN/MTGNN.py | 134 ++++++++++ model/MTGNN/layer.py | 328 +++++++++++++++++++++++++ model/PatchTST/PatchTST.py | 109 ++++++++ model/PatchTST/layers/Embed.py | 29 +++ model/PatchTST/layers/SelfAttention.py | 80 ++++++ model/PatchTST/layers/Transformer.py | 57 +++++ model/model_selector.py | 3 + train.py | 5 +- trainer/Trainer.py | 1 - 18 files changed, 1177 insertions(+), 3 deletions(-) create mode 100644 config/PatchTST/AirQuality.yaml create mode 100644 config/PatchTST/BJTaxi-Inflow.yaml create mode 100644 config/PatchTST/BJTaxi-Outflow.yaml create mode 100644 config/PatchTST/METR-LA.yaml create mode 100644 config/PatchTST/NYCBike-Inflow.yaml create mode 100644 config/PatchTST/NYCBike-Outflow.yaml create mode 100644 config/PatchTST/PEMS-BAY.yaml create mode 100644 config/PatchTST/SolarEnergy.yaml create mode 100644 model/MTGNN/MTGNN.py create mode 100644 model/MTGNN/layer.py create mode 100644 model/PatchTST/PatchTST.py create mode 100644 model/PatchTST/layers/Embed.py create mode 100644 model/PatchTST/layers/SelfAttention.py create mode 100644 model/PatchTST/layers/Transformer.py diff --git a/config/PatchTST/AirQuality.yaml b/config/PatchTST/AirQuality.yaml new file mode 100644 index 0000000..a3e6418 --- /dev/null +++ b/config/PatchTST/AirQuality.yaml @@ -0,0 +1,54 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + patch_len: 6 + stride: 8 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/BJTaxi-Inflow.yaml b/config/PatchTST/BJTaxi-Inflow.yaml new file mode 100644 index 0000000..9bd66d9 --- /dev/null +++ b/config/PatchTST/BJTaxi-Inflow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: BJTaxi-InFlow + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 2048 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + patch_len: 6 + stride: 8 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 2048 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + 
lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/BJTaxi-Outflow.yaml b/config/PatchTST/BJTaxi-Outflow.yaml new file mode 100644 index 0000000..2382695 --- /dev/null +++ b/config/PatchTST/BJTaxi-Outflow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: BJTaxi-OutFlow + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 2048 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + patch_len: 6 + stride: 8 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 2048 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/METR-LA.yaml b/config/PatchTST/METR-LA.yaml new file mode 100644 index 0000000..d076d35 --- /dev/null +++ b/config/PatchTST/METR-LA.yaml @@ -0,0 +1,54 @@ +basic: + dataset: METR-LA + device: cuda:1 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + patch_len: 6 + stride: 8 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/NYCBike-Inflow.yaml b/config/PatchTST/NYCBike-Inflow.yaml new file mode 100644 index 0000000..2c3026c --- /dev/null +++ b/config/PatchTST/NYCBike-Inflow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: NYCBike-InFlow + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + patch_len: 6 + stride: 8 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/NYCBike-Outflow.yaml b/config/PatchTST/NYCBike-Outflow.yaml new file mode 100644 
index 0000000..16eee20 --- /dev/null +++ b/config/PatchTST/NYCBike-Outflow.yaml @@ -0,0 +1,54 @@ +basic: + dataset: NYCBike-OutFlow + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + patch_len: 6 + stride: 8 + d_model: 128 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/PEMS-BAY.yaml b/config/PatchTST/PEMS-BAY.yaml new file mode 100644 index 0000000..6186db3 --- /dev/null +++ b/config/PatchTST/PEMS-BAY.yaml @@ -0,0 +1,54 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + d_model: 128 + patch_len: 6 + stride: 8 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/SolarEnergy.yaml b/config/PatchTST/SolarEnergy.yaml new file mode 100644 index 0000000..28b85b9 --- /dev/null +++ b/config/PatchTST/SolarEnergy.yaml @@ -0,0 +1,54 @@ +basic: + dataset: SolarEnergy + device: cuda:0 + mode: train + model: iTransformer + seed: 2023 + +data: + batch_size: 256 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + activation: gelu + seq_len: 24 + pred_len: 24 + d_model: 128 + patch_len: 6 + stride: 8 + d_ff: 2048 + dropout: 0.1 + e_layers: 2 + n_heads: 8 + output_attention: False + + +train: + batch_size: 256 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.0001 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/dataloader/loader_selector.py b/dataloader/loader_selector.py index f9bf823..88d1e2d 100755 --- a/dataloader/loader_selector.py +++ b/dataloader/loader_selector.py @@ -7,7 +7,7 @@ from dataloader.TSloader import get_dataloader as TS_loader def get_dataloader(config, normalizer, single): - TS_model = ["iTransformer", "HI"] + TS_model = ["iTransformer", "HI", "PatchTST"] model_name = config["basic"]["model"] if model_name in TS_model: return 
TS_loader(config, normalizer, single) diff --git a/model/MTGNN/MTGNN.py b/model/MTGNN/MTGNN.py new file mode 100644 index 0000000..483a184 --- /dev/null +++ b/model/MTGNN/MTGNN.py @@ -0,0 +1,134 @@ +import torch.nn as nn +from model.MTGNN.layer import * + + +class gtnet(nn.Module): + def __init__(self, gcn_true, buildA_true, gcn_depth, num_nodes, device, predefined_A=None, static_feat=None, dropout=0.3, subgraph_size=20, node_dim=40, dilation_exponential=1, conv_channels=32, residual_channels=32, skip_channels=64, end_channels=128, seq_length=12, in_dim=2, out_dim=12, layers=3, propalpha=0.05, tanhalpha=3, layer_norm_affline=True): + super(gtnet, self).__init__() + self.gcn_true = gcn_true + self.buildA_true = buildA_true + self.num_nodes = num_nodes + self.dropout = dropout + self.predefined_A = predefined_A + self.filter_convs = nn.ModuleList() + self.gate_convs = nn.ModuleList() + self.residual_convs = nn.ModuleList() + self.skip_convs = nn.ModuleList() + self.gconv1 = nn.ModuleList() + self.gconv2 = nn.ModuleList() + self.norm = nn.ModuleList() + self.start_conv = nn.Conv2d(in_channels=in_dim, + out_channels=residual_channels, + kernel_size=(1, 1)) + self.gc = graph_constructor(num_nodes, subgraph_size, node_dim, device, alpha=tanhalpha, static_feat=static_feat) + + self.seq_length = seq_length + kernel_size = 7 + if dilation_exponential>1: + self.receptive_field = int(1+(kernel_size-1)*(dilation_exponential**layers-1)/(dilation_exponential-1)) + else: + self.receptive_field = layers*(kernel_size-1) + 1 + + for i in range(1): + if dilation_exponential>1: + rf_size_i = int(1 + i*(kernel_size-1)*(dilation_exponential**layers-1)/(dilation_exponential-1)) + else: + rf_size_i = i*layers*(kernel_size-1)+1 + new_dilation = 1 + for j in range(1,layers+1): + if dilation_exponential > 1: + rf_size_j = int(rf_size_i + (kernel_size-1)*(dilation_exponential**j-1)/(dilation_exponential-1)) + else: + rf_size_j = rf_size_i+j*(kernel_size-1) + + self.filter_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation)) + self.gate_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation)) + self.residual_convs.append(nn.Conv2d(in_channels=conv_channels, + out_channels=residual_channels, + kernel_size=(1, 1))) + if self.seq_length>self.receptive_field: + self.skip_convs.append(nn.Conv2d(in_channels=conv_channels, + out_channels=skip_channels, + kernel_size=(1, self.seq_length-rf_size_j+1))) + else: + self.skip_convs.append(nn.Conv2d(in_channels=conv_channels, + out_channels=skip_channels, + kernel_size=(1, self.receptive_field-rf_size_j+1))) + + if self.gcn_true: + self.gconv1.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha)) + self.gconv2.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha)) + + if self.seq_length>self.receptive_field: + self.norm.append(LayerNorm((residual_channels, num_nodes, self.seq_length - rf_size_j + 1),elementwise_affine=layer_norm_affline)) + else: + self.norm.append(LayerNorm((residual_channels, num_nodes, self.receptive_field - rf_size_j + 1),elementwise_affine=layer_norm_affline)) + + new_dilation *= dilation_exponential + + self.layers = layers + self.end_conv_1 = nn.Conv2d(in_channels=skip_channels, + out_channels=end_channels, + kernel_size=(1,1), + bias=True) + self.end_conv_2 = nn.Conv2d(in_channels=end_channels, + out_channels=out_dim, + kernel_size=(1,1), + bias=True) + if self.seq_length > self.receptive_field: + self.skip0 = 
nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.seq_length), bias=True) + self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, self.seq_length-self.receptive_field+1), bias=True) + + else: + self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.receptive_field), bias=True) + self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, 1), bias=True) + + + self.idx = torch.arange(self.num_nodes).to(device) + + + def forward(self, input, idx=None): + seq_len = input.size(3) + assert seq_len==self.seq_length, 'input sequence length not equal to preset sequence length' + + if self.seq_lengthncvl',(x,A)) + return x.contiguous() + +class dy_nconv(nn.Module): + def __init__(self): + super(dy_nconv,self).__init__() + + def forward(self,x, A): + x = torch.einsum('ncvl,nvwl->ncwl',(x,A)) + return x.contiguous() + +class linear(nn.Module): + def __init__(self,c_in,c_out,bias=True): + super(linear,self).__init__() + self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0,0), stride=(1,1), bias=bias) + + def forward(self,x): + return self.mlp(x) + + +class prop(nn.Module): + def __init__(self,c_in,c_out,gdep,dropout,alpha): + super(prop, self).__init__() + self.nconv = nconv() + self.mlp = linear(c_in,c_out) + self.gdep = gdep + self.dropout = dropout + self.alpha = alpha + + def forward(self,x,adj): + adj = adj + torch.eye(adj.size(0)).to(x.device) + d = adj.sum(1) + h = x + dv = d + a = adj / dv.view(-1, 1) + for i in range(self.gdep): + h = self.alpha*x + (1-self.alpha)*self.nconv(h,a) + ho = self.mlp(h) + return ho + + +class mixprop(nn.Module): + def __init__(self,c_in,c_out,gdep,dropout,alpha): + super(mixprop, self).__init__() + self.nconv = nconv() + self.mlp = linear((gdep+1)*c_in,c_out) + self.gdep = gdep + self.dropout = dropout + self.alpha = alpha + + + def forward(self,x,adj): + adj = adj + torch.eye(adj.size(0)).to(x.device) + d = adj.sum(1) + h = x + out = [h] + a = adj / d.view(-1, 1) + for i in range(self.gdep): + h = self.alpha*x + (1-self.alpha)*self.nconv(h,a) + out.append(h) + ho = torch.cat(out,dim=1) + ho = self.mlp(ho) + return ho + +class dy_mixprop(nn.Module): + def __init__(self,c_in,c_out,gdep,dropout,alpha): + super(dy_mixprop, self).__init__() + self.nconv = dy_nconv() + self.mlp1 = linear((gdep+1)*c_in,c_out) + self.mlp2 = linear((gdep+1)*c_in,c_out) + + self.gdep = gdep + self.dropout = dropout + self.alpha = alpha + self.lin1 = linear(c_in,c_in) + self.lin2 = linear(c_in,c_in) + + + def forward(self,x): + #adj = adj + torch.eye(adj.size(0)).to(x.device) + #d = adj.sum(1) + x1 = torch.tanh(self.lin1(x)) + x2 = torch.tanh(self.lin2(x)) + adj = self.nconv(x1.transpose(2,1),x2) + adj0 = torch.softmax(adj, dim=2) + adj1 = torch.softmax(adj.transpose(2,1), dim=2) + + h = x + out = [h] + for i in range(self.gdep): + h = self.alpha*x + (1-self.alpha)*self.nconv(h,adj0) + out.append(h) + ho = torch.cat(out,dim=1) + ho1 = self.mlp1(ho) + + + h = x + out = [h] + for i in range(self.gdep): + h = self.alpha * x + (1 - self.alpha) * self.nconv(h, adj1) + out.append(h) + ho = torch.cat(out, dim=1) + ho2 = self.mlp2(ho) + + return ho1+ho2 + + + +class dilated_1D(nn.Module): + def __init__(self, cin, cout, dilation_factor=2): + super(dilated_1D, self).__init__() + self.tconv = nn.ModuleList() + self.kernel_set = [2,3,6,7] + self.tconv = nn.Conv2d(cin,cout,(1,7),dilation=(1,dilation_factor)) + + def forward(self,input): + x 
= self.tconv(input) + return x + +class dilated_inception(nn.Module): + def __init__(self, cin, cout, dilation_factor=2): + super(dilated_inception, self).__init__() + self.tconv = nn.ModuleList() + self.kernel_set = [2,3,6,7] + cout = int(cout/len(self.kernel_set)) + for kern in self.kernel_set: + self.tconv.append(nn.Conv2d(cin,cout,(1,kern),dilation=(1,dilation_factor))) + + def forward(self,input): + x = [] + for i in range(len(self.kernel_set)): + x.append(self.tconv[i](input)) + for i in range(len(self.kernel_set)): + x[i] = x[i][...,-x[-1].size(3):] + x = torch.cat(x,dim=1) + return x + + +class graph_constructor(nn.Module): + def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): + super(graph_constructor, self).__init__() + self.nnodes = nnodes + if static_feat is not None: + xd = static_feat.shape[1] + self.lin1 = nn.Linear(xd, dim) + self.lin2 = nn.Linear(xd, dim) + else: + self.emb1 = nn.Embedding(nnodes, dim) + self.emb2 = nn.Embedding(nnodes, dim) + self.lin1 = nn.Linear(dim,dim) + self.lin2 = nn.Linear(dim,dim) + + self.device = device + self.k = k + self.dim = dim + self.alpha = alpha + self.static_feat = static_feat + + def forward(self, idx): + if self.static_feat is None: + nodevec1 = self.emb1(idx) + nodevec2 = self.emb2(idx) + else: + nodevec1 = self.static_feat[idx,:] + nodevec2 = nodevec1 + + nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) + nodevec2 = torch.tanh(self.alpha*self.lin2(nodevec2)) + + a = torch.mm(nodevec1, nodevec2.transpose(1,0))-torch.mm(nodevec2, nodevec1.transpose(1,0)) + adj = F.relu(torch.tanh(self.alpha*a)) + mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device) + mask.fill_(float('0')) + s1,t1 = (adj + torch.rand_like(adj)*0.01).topk(self.k,1) + mask.scatter_(1,t1,s1.fill_(1)) + adj = adj*mask + return adj + + def fullA(self, idx): + if self.static_feat is None: + nodevec1 = self.emb1(idx) + nodevec2 = self.emb2(idx) + else: + nodevec1 = self.static_feat[idx,:] + nodevec2 = nodevec1 + + nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) + nodevec2 = torch.tanh(self.alpha*self.lin2(nodevec2)) + + a = torch.mm(nodevec1, nodevec2.transpose(1,0))-torch.mm(nodevec2, nodevec1.transpose(1,0)) + adj = F.relu(torch.tanh(self.alpha*a)) + return adj + +class graph_global(nn.Module): + def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): + super(graph_global, self).__init__() + self.nnodes = nnodes + self.A = nn.Parameter(torch.randn(nnodes, nnodes).to(device), requires_grad=True).to(device) + + def forward(self, idx): + return F.relu(self.A) + + +class graph_undirected(nn.Module): + def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): + super(graph_undirected, self).__init__() + self.nnodes = nnodes + if static_feat is not None: + xd = static_feat.shape[1] + self.lin1 = nn.Linear(xd, dim) + else: + self.emb1 = nn.Embedding(nnodes, dim) + self.lin1 = nn.Linear(dim,dim) + + self.device = device + self.k = k + self.dim = dim + self.alpha = alpha + self.static_feat = static_feat + + def forward(self, idx): + if self.static_feat is None: + nodevec1 = self.emb1(idx) + nodevec2 = self.emb1(idx) + else: + nodevec1 = self.static_feat[idx,:] + nodevec2 = nodevec1 + + nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) + nodevec2 = torch.tanh(self.alpha*self.lin1(nodevec2)) + + a = torch.mm(nodevec1, nodevec2.transpose(1,0)) + adj = F.relu(torch.tanh(self.alpha*a)) + mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device) + mask.fill_(float('0')) + s1,t1 = adj.topk(self.k,1) + 
mask.scatter_(1,t1,s1.fill_(1)) + adj = adj*mask + return adj + + + +class graph_directed(nn.Module): + def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None): + super(graph_directed, self).__init__() + self.nnodes = nnodes + if static_feat is not None: + xd = static_feat.shape[1] + self.lin1 = nn.Linear(xd, dim) + self.lin2 = nn.Linear(xd, dim) + else: + self.emb1 = nn.Embedding(nnodes, dim) + self.emb2 = nn.Embedding(nnodes, dim) + self.lin1 = nn.Linear(dim,dim) + self.lin2 = nn.Linear(dim,dim) + + self.device = device + self.k = k + self.dim = dim + self.alpha = alpha + self.static_feat = static_feat + + def forward(self, idx): + if self.static_feat is None: + nodevec1 = self.emb1(idx) + nodevec2 = self.emb2(idx) + else: + nodevec1 = self.static_feat[idx,:] + nodevec2 = nodevec1 + + nodevec1 = torch.tanh(self.alpha*self.lin1(nodevec1)) + nodevec2 = torch.tanh(self.alpha*self.lin2(nodevec2)) + + a = torch.mm(nodevec1, nodevec2.transpose(1,0)) + adj = F.relu(torch.tanh(self.alpha*a)) + mask = torch.zeros(idx.size(0), idx.size(0)).to(self.device) + mask.fill_(float('0')) + s1,t1 = adj.topk(self.k,1) + mask.scatter_(1,t1,s1.fill_(1)) + adj = adj*mask + return adj + + +class LayerNorm(nn.Module): + __constants__ = ['normalized_shape', 'weight', 'bias', 'eps', 'elementwise_affine'] + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = tuple(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.Tensor(*normalized_shape)) + self.bias = nn.Parameter(torch.Tensor(*normalized_shape)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + self.reset_parameters() + + + def reset_parameters(self): + if self.elementwise_affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, input, idx): + if self.elementwise_affine: + return F.layer_norm(input, tuple(input.shape[1:]), self.weight[:,idx,:], self.bias[:,idx,:], self.eps) + else: + return F.layer_norm(input, tuple(input.shape[1:]), self.weight, self.bias, self.eps) + + def extra_repr(self): + return '{normalized_shape}, eps={eps}, ' \ + 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) \ No newline at end of file diff --git a/model/PatchTST/PatchTST.py b/model/PatchTST/PatchTST.py new file mode 100644 index 0000000..3112030 --- /dev/null +++ b/model/PatchTST/PatchTST.py @@ -0,0 +1,109 @@ +import torch +from torch import nn +from model.PatchTST.layers.Transformer import Encoder, EncoderLayer +from model.PatchTST.layers.SelfAttention import FullAttention, AttentionLayer +from model.PatchTST.layers.Embed import PatchEmbedding + +class Transpose(nn.Module): + def __init__(self, *dims, contiguous=False): + super().__init__() + self.dims, self.contiguous = dims, contiguous + def forward(self, x): + if self.contiguous: return x.transpose(*self.dims).contiguous() + else: return x.transpose(*self.dims) + + +class FlattenHead(nn.Module): + def __init__(self, n_vars, nf, target_window, head_dropout=0): + super().__init__() + self.n_vars = n_vars + self.flatten = nn.Flatten(start_dim=-2) + self.linear = nn.Linear(nf, target_window) + self.dropout = nn.Dropout(head_dropout) + + def forward(self, x): # x: [bs x nvars x d_model x patch_num] + x = self.flatten(x) + x = self.linear(x) + x = self.dropout(x) + 
return x + + +class Model(nn.Module): + """ + Paper link: https://arxiv.org/pdf/2211.14730.pdf + """ + + def __init__(self, configs): + """ + patch_len: int, patch len for patch_embedding + stride: int, stride for patch_embedding + """ + super().__init__() + self.seq_len = configs['seq_len'] + self.pred_len = configs['pred_len'] + self.patch_len = configs['patch_len'] + self.stride = configs['stride'] + padding = self.stride + + # patching and embedding + self.patch_embedding = PatchEmbedding( + configs['d_model'], self.patch_len, self.stride, padding, configs['dropout']) + + # Encoder + self.encoder = Encoder( + [ + EncoderLayer( + AttentionLayer( + FullAttention(False, attention_dropout=configs['dropout'], + output_attention=False), configs['d_model'], configs['n_heads']), + configs['d_model'], + configs['d_ff'], + dropout=configs['dropout'], + activation=configs['activation'] + ) for l in range(configs['e_layers']) + ], + norm_layer=nn.Sequential(Transpose(1,2), nn.BatchNorm1d(configs.d_model), Transpose(1,2)) + ) + + # Prediction Head + self.head_nf = configs.d_model * \ + int((configs.seq_len - self.patch_len) / self.stride + 2) + self.head = FlattenHead(configs.enc_in, self.head_nf, configs.pred_len, + head_dropout=configs.dropout) + + def forecast(self, x_enc): + # Normalization from Non-stationary Transformer + means = x_enc.mean(1, keepdim=True).detach() + x_enc = x_enc - means + stdev = torch.sqrt( + torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5) + x_enc /= stdev + + # do patching and embedding + x_enc = x_enc.permute(0, 2, 1) + # u: [bs * nvars x patch_num x d_model] + enc_out, n_vars = self.patch_embedding(x_enc) + + # Encoder + # z: [bs * nvars x patch_num x d_model] + enc_out, attns = self.encoder(enc_out) + # z: [bs x nvars x patch_num x d_model] + enc_out = torch.reshape( + enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])) + # z: [bs x nvars x d_model x patch_num] + enc_out = enc_out.permute(0, 1, 3, 2) + + # Decoder + dec_out = self.head(enc_out) # z: [bs x nvars x target_window] + dec_out = dec_out.permute(0, 2, 1) + + # De-Normalization from Non-stationary Transformer + dec_out = dec_out * \ + (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) + dec_out = dec_out + \ + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len, 1)) + return dec_out + + def forward(self, x_enc): + dec_out = self.forecast(x_enc) + return dec_out[:, -self.pred_len:, :] # [B, L, D] diff --git a/model/PatchTST/layers/Embed.py b/model/PatchTST/layers/Embed.py new file mode 100644 index 0000000..94896e0 --- /dev/null +++ b/model/PatchTST/layers/Embed.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn + +class PatchEmbedding(nn.Module): + def __init__(self, d_model, patch_len, stride, padding, dropout): + super(PatchEmbedding, self).__init__() + # Patching + self.patch_len = patch_len + self.stride = stride + self.padding_patch_layer = nn.ReplicationPad1d((0, padding)) + + # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space + self.value_embedding = nn.Linear(patch_len, d_model, bias=False) + + # Positional embedding + self.position_embedding = PositionalEmbedding(d_model) + + # Residual dropout + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + # do patching + n_vars = x.shape[1] + x = self.padding_patch_layer(x) + x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) + x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])) + # Input encoding + x = self.value_embedding(x) + 
self.position_embedding(x) + return self.dropout(x), n_vars \ No newline at end of file diff --git a/model/PatchTST/layers/SelfAttention.py b/model/PatchTST/layers/SelfAttention.py new file mode 100644 index 0000000..55b2493 --- /dev/null +++ b/model/PatchTST/layers/SelfAttention.py @@ -0,0 +1,80 @@ +import torch +import torch.nn as nn +import numpy as np +from math import sqrt + +class FullAttention(nn.Module): + def __init__(self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False): + super(FullAttention, self).__init__() + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): + B, L, H, E = queries.shape + _, S, _, D = values.shape + scale = self.scale or 1. / sqrt(E) + + scores = torch.einsum("blhe,bshe->bhls", queries, keys) + + if self.mask_flag: + if attn_mask is None: + attn_mask = TriangularCausalMask(B, L, device=queries.device) + + scores.masked_fill_(attn_mask.mask, -np.inf) + + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + V = torch.einsum("bhls,bshd->blhd", A, values) + + if self.output_attention: + return V.contiguous(), A + else: + return V.contiguous(), None + +class AttentionLayer(nn.Module): + def __init__(self, attention, d_model, n_heads, d_keys=None, + d_values=None): + super(AttentionLayer, self).__init__() + + d_keys = d_keys or (d_model // n_heads) + d_values = d_values or (d_model // n_heads) + + self.inner_attention = attention + self.query_projection = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_model, d_keys * n_heads) + self.value_projection = nn.Linear(d_model, d_values * n_heads) + self.out_projection = nn.Linear(d_values * n_heads, d_model) + self.n_heads = n_heads + + def forward(self, queries, keys, values, attn_mask, tau=None, delta=None): + B, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + queries = self.query_projection(queries).view(B, L, H, -1) + keys = self.key_projection(keys).view(B, S, H, -1) + values = self.value_projection(values).view(B, S, H, -1) + + out, attn = self.inner_attention( + queries, + keys, + values, + attn_mask, + tau=tau, + delta=delta + ) + out = out.view(B, L, -1) + + return self.out_projection(out), attn + + +class TriangularCausalMask: + def __init__(self, B, L, device="cpu"): + mask_shape = [B, 1, L, L] + with torch.no_grad(): + self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) + + @property + def mask(self): + return self._mask \ No newline at end of file diff --git a/model/PatchTST/layers/Transformer.py b/model/PatchTST/layers/Transformer.py new file mode 100644 index 0000000..6116325 --- /dev/null +++ b/model/PatchTST/layers/Transformer.py @@ -0,0 +1,57 @@ +import torch.nn as nn +import torch.nn.functional as F + +class EncoderLayer(nn.Module): + def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): + super(EncoderLayer, self).__init__() + d_ff = d_ff or 4 * d_model + self.attention = attention + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, attn_mask=None, tau=None, delta=None): + new_x, attn = 
self.attention( + x, x, x, + attn_mask=attn_mask, + tau=tau, delta=delta + ) + x = x + self.dropout(new_x) + + y = x = self.norm1(x) + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + + return self.norm2(x + y), attn + + +class Encoder(nn.Module): + def __init__(self, attn_layers, conv_layers=None, norm_layer=None): + super(Encoder, self).__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None + self.norm = norm_layer + + def forward(self, x, attn_mask=None, tau=None, delta=None): + # x [B, L, D] + attns = [] + if self.conv_layers is not None: + for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)): + delta = delta if i == 0 else None + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) + x = conv_layer(x) + attns.append(attn) + x, attn = self.attn_layers[-1](x, tau=tau, delta=None) + attns.append(attn) + else: + for attn_layer in self.attn_layers: + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) + attns.append(attn) + + if self.norm is not None: + x = self.norm(x) + + return x, attns \ No newline at end of file diff --git a/model/model_selector.py b/model/model_selector.py index 7403893..09b7fdc 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -29,6 +29,7 @@ from model.ASTRA.astrav2 import ASTRA as ASTRAv2 from model.ASTRA.astrav3 import ASTRA as ASTRAv3 from model.iTransformer.iTransformer import iTransformer from model.HI.HI import HI +from model.PatchTST.PatchTST import Model as PatchTST @@ -96,3 +97,5 @@ def model_selector(config): return iTransformer(model_config) case "HI": return HI(model_config) + case "PatchTST": + return PatchTST(model_config) diff --git a/train.py b/train.py index 9d58921..dad4609 100644 --- a/train.py +++ b/train.py @@ -45,11 +45,13 @@ def run(config): if __name__ == "__main__": # 指定模型 - model_list = ["HI"] + model_list = ["PatchTST"] # 指定数据集 dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] + # dataset_list = ["AirQuality"] device = "cuda:0" # 指定设备 seed = 2023 # 随机种子 + epochs = 1 for model in model_list: for dataset in dataset_list: config_path = f"./config/{model}/{dataset}.yaml" @@ -57,6 +59,7 @@ if __name__ == "__main__": config = yaml.safe_load(file) config["basic"]["device"] = device config["basic"]["seed"] = seed + config["train"]["epochs"] = epochs print(f"\nRunning {model} on {dataset} with seed {seed} on {device}") print(f"config: {config}") run(config) diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 4bd82a4..80a6672 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -2,7 +2,6 @@ import math import os import time import copy -import psutil import torch from utils.logger import get_logger from utils.loss_function import all_metrics
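
Shape walk-through of the patching pipeline (illustration only, not part of the patch). The sketch below mirrors PatchEmbedding in model/PatchTST/layers/Embed.py and FlattenHead in model/PatchTST/PatchTST.py with plain tensor ops, using the hyper-parameters shipped in config/PatchTST/METR-LA.yaml (seq_len=24, patch_len=6, stride=8, padding=stride, d_model=128, num_nodes=207). It deliberately does not instantiate the committed modules: Embed.py as added references a PositionalEmbedding class it neither defines nor imports, and the prediction head in PatchTST.py reads configs.enc_in / configs.d_model by attribute access while the rest of the file indexes the YAML-derived dict as configs['...'], so end-to-end construction is assumed rather than demonstrated here. All names below are local to this sketch.

import torch
import torch.nn as nn

# Hyper-parameters taken from config/PatchTST/METR-LA.yaml
B, N, L = 2, 207, 24          # batch, num_nodes (variates), seq_len
patch_len, stride = 6, 8
padding = stride               # PatchTST.py sets padding = self.stride
d_model, pred_len = 128, 24

x = torch.randn(B, L, N)       # model input: [B, seq_len, num_nodes]

# --- PatchEmbedding.forward, step by step ---
x = x.permute(0, 2, 1)                                   # [B, N, L]
x = nn.ReplicationPad1d((0, padding))(x)                 # [B, N, L + padding] = [2, 207, 32]
x = x.unfold(dimension=-1, size=patch_len, step=stride)  # [B, N, patch_num, patch_len] = [2, 207, 4, 6]
patch_num = x.shape[2]                                   # (32 - 6) // 8 + 1 = 4
x = x.reshape(B * N, patch_num, patch_len)               # channel independence: [B*N, patch_num, patch_len]
x = nn.Linear(patch_len, d_model, bias=False)(x)         # value embedding: [B*N, patch_num, d_model]
# (the committed Embed.py additionally adds PositionalEmbedding(d_model),
#  assumed to be the usual sinusoidal table; that class is not defined in the new file)

# --- after the Transformer encoder (shape-preserving) ---
enc_out = x.reshape(B, N, patch_num, d_model).permute(0, 1, 3, 2)  # [B, N, d_model, patch_num]

# --- FlattenHead ---
head_nf = d_model * int((L - patch_len) / stride + 2)    # 128 * 4 = 512, equals d_model * patch_num
out = nn.Flatten(start_dim=-2)(enc_out)                  # [B, N, head_nf]
out = nn.Linear(head_nf, pred_len)(out)                  # [B, N, pred_len]
out = out.permute(0, 2, 1)                               # [B, pred_len, N], the model's output layout
print(patch_num, head_nf, out.shape)                     # 4 512 torch.Size([2, 24, 207])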