diff --git a/config/FPT/AirQuality.yaml b/config/FPT/AirQuality.yaml new file mode 100644 index 0000000..0604938 --- /dev/null +++ b/config/FPT/AirQuality.yaml @@ -0,0 +1,51 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 6 + n_heads: 1 + num_nodes: 35 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + weight_decay: 0 diff --git a/config/FPT/BJTaxi-InFlow.yaml b/config/FPT/BJTaxi-InFlow.yaml new file mode 100644 index 0000000..18abb67 --- /dev/null +++ b/config/FPT/BJTaxi-InFlow.yaml @@ -0,0 +1,51 @@ +basic: + dataset: BJTaxi-InFlow + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 1024 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/FPT/BJTaxi-OutFlow.yaml 
b/config/FPT/BJTaxi-OutFlow.yaml new file mode 100644 index 0000000..3e6765a --- /dev/null +++ b/config/FPT/BJTaxi-OutFlow.yaml @@ -0,0 +1,51 @@ +basic: + dataset: BJTaxi-OutFlow + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 1024 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/FPT/METR-LA.yaml b/config/FPT/METR-LA.yaml new file mode 100644 index 0000000..0c22dcb --- /dev/null +++ b/config/FPT/METR-LA.yaml @@ -0,0 +1,52 @@ +basic: + dataset: METR-LA + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 207 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 diff --git a/config/FPT/NYCBike-InFlow.yaml b/config/FPT/NYCBike-InFlow.yaml new file 
mode 100644 index 0000000..41a8c8b --- /dev/null +++ b/config/FPT/NYCBike-InFlow.yaml @@ -0,0 +1,51 @@ +basic: + dataset: NYCBike-InFlow + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 128 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/FPT/NYCBike-OutFlow.yaml b/config/FPT/NYCBike-OutFlow.yaml new file mode 100644 index 0000000..cc52b1a --- /dev/null +++ b/config/FPT/NYCBike-OutFlow.yaml @@ -0,0 +1,51 @@ +basic: + dataset: NYCBike-OutFlow + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 32 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 128 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 32 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/FPT/PEMS-BAY.yaml b/config/FPT/PEMS-BAY.yaml new file mode 100755 index 0000000..efe4d7c --- 
/dev/null +++ b/config/FPT/PEMS-BAY.yaml @@ -0,0 +1,51 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 325 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/config/FPT/SolarEnergy.yaml b/config/FPT/SolarEnergy.yaml new file mode 100644 index 0000000..fe1ea22 --- /dev/null +++ b/config/FPT/SolarEnergy.yaml @@ -0,0 +1,51 @@ +basic: + dataset: SolarEnergy + device: cuda:0 + mode: train + model: FPT + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + d_model: 768 + gpt_layers: 9 + gpt_path: ./GPT-2 + input_dim: 1 + n_heads: 1 + num_nodes: 137 + patch_len: 6 + pred_len: 24 + seq_len: 24 + stride: 7 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 100 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + weight_decay: 0 diff --git a/model/ASTRA/astra.py b/model/ASTRA/astra.py index 71d4ee9..f0d32e5 100644 --- a/model/ASTRA/astra.py +++ b/model/ASTRA/astra.py @@ -206,7 +206,6 @@ class ASTRA(nn.Module): 
enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, d_model * input_dim) # 应用图增强编码器(自动生成图结构) graph_enhanced = self.graph_encoder(enc_out) # (B, N, K * hidden_dim) - # 特征融合 - 现在两个张量都是三维的 [B, N, d_model] enc_out = torch.cat([enc_out, graph_enhanced], dim=-1) enc_out = self.feature_fusion(enc_out) diff --git a/model/FPT/fpt.py b/model/FPT/fpt.py new file mode 100644 index 0000000..941da6d --- /dev/null +++ b/model/FPT/fpt.py @@ -0,0 +1,45 @@ +import torch.nn as nn +from transformers.models.gpt2.modeling_gpt2 import GPT2Model +from einops import rearrange + +class fpt(nn.Module): + def __init__(self, configs): + super(fpt, self).__init__() + self.patch_len = configs['patch_len'] + self.stride = configs['stride'] + self.input_dim = configs['input_dim'] + self.seq_len = configs['seq_len'] + self.pred_len = configs['pred_len'] + self.gpt_layers = configs['gpt_layers'] # 使用的GPT2层数 + self.d_model = configs['d_model'] + self.gpt_path = configs['gpt_path'] + + self.patch_num = int((self.seq_len - self.patch_len) / self.stride + 2) # 补丁数量 + self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride)) + + self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True) + self.gpts.h = self.gpts.h[:self.gpt_layers] + for name, param in self.gpts.named_parameters(): + if 'wpe' in name: + param.requires_grad = True + else: + param.requires_grad = False + + self.in_layer = nn.Linear(self.patch_len, self.d_model) + self.out_layer = nn.Linear(self.d_model * self.patch_num, self.pred_len) + + def forward(self, x): + x = x[..., :self.input_dim] + B, L, M = x.shape + x = rearrange(x, 'b l m -> b m l') + + x = self.padding_patch_layer(x) + x = x.unfold(dimension = -1, size = self.patch_len, step = self.stride) + x = rearrange(x, 'b m n p -> (b m) n p') + + outputs = self.in_layer(x) + outputs = self.gpts(inputs_embeds=outputs).last_hidden_state + outputs = self.out_layer(outputs.reshape(B*M, -1)) + outputs = rearrange(outputs, 
'(b m) l -> b l m', b = B) + return outputs + diff --git a/model/FPT/model_config.json b/model/FPT/model_config.json new file mode 100644 index 0000000..a7d040c --- /dev/null +++ b/model/FPT/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "FPT", + "module": "model.FPT.fpt", + "entry": "fpt" + } +] \ No newline at end of file diff --git a/train.py b/train.py index 40b3bcb..7242ac0 100644 --- a/train.py +++ b/train.py @@ -90,9 +90,9 @@ def main(model, data, debug=False): if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] - # model_list = ["ASTRA_v3"] - model_list = ["PatchTST"] - dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] + model_list = ["FPT"] + # model_list = ["PatchTST"] + dataset_list = ["METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality"] # dataset_list = ["METR-LA"] main(model_list, dataset_list, debug = False) \ No newline at end of file diff --git a/trainer/trainer_selector.py b/trainer/trainer_selector.py index 17aa81d..24d8a10 100755 --- a/trainer/trainer_selector.py +++ b/trainer/trainer_selector.py @@ -20,7 +20,7 @@ def select_trainer( scaler, args, lr_scheduler ) - if model_name in {"HI", "PatchTST", "iTransformer"}: + if model_name in {"HI", "PatchTST", "iTransformer", "FPT"}: return TSTrainer(*base_args) trainer_map = {