From dfc76b8e9064e0a3c6e1369bdfb58a1c95954e24 Mon Sep 17 00:00:00 2001
From: czzhangheng
Date: Mon, 22 Dec 2025 17:25:52 +0800
Subject: [PATCH] Improve trainer best-model update messages; remove redundant
 Informer code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 model/Informer/model.py | 55 +++++++++--------------------------------
 train.py                | 12 ++++-----
 trainer/TSTrainer.py    |  8 ++++--
 trainer/Trainer.py      |  8 ++++--
 4 files changed, 30 insertions(+), 53 deletions(-)

diff --git a/model/Informer/model.py b/model/Informer/model.py
index c7cb469..8cf072a 100644
--- a/model/Informer/model.py
+++ b/model/Informer/model.py
@@ -1,4 +1,3 @@
-
 import torch
 import torch.nn as nn
 
@@ -9,12 +8,6 @@ from model.Informer.head import TemporalProjectionHead
 
 
 class InformerEncoder(nn.Module):
-    """
-    Informer Encoder-only
-    - Only uses x
-    - No normalization
-    - Multi-channel friendly
-    """
     def __init__(self, configs):
         super().__init__()
 
@@ -25,41 +18,20 @@ class InformerEncoder(nn.Module):
         Attn = ProbAttention if configs["attn"] == "prob" else FullAttention
 
         # Embedding
-        self.embedding = DataEmbedding(
-            configs["enc_in"],
-            configs["d_model"],
-            configs["dropout"],
-        )
+        self.embedding = DataEmbedding(configs["enc_in"], configs["d_model"], configs["dropout"])
 
-        # Encoder (Informer)
+        # Encoder(Attn-Conv-Norm)
         self.encoder = Encoder(
-            [
-                EncoderLayer(
-                    AttentionLayer(
-                        Attn(
-                            False,
-                            configs["factor"],
-                            attention_dropout=configs["dropout"],
-                            output_attention=False,
-                        ),
-                        configs["d_model"],
-                        configs["n_heads"],
-                        mix=False,
-                    ),
-                    configs["d_model"],
-                    configs["d_ff"],
-                    dropout=configs["dropout"],
-                    activation=configs["activation"],
-                )
-                for _ in range(configs["e_layers"])
-            ],
-            [
-                ConvLayer(configs["d_model"])
-                for _ in range(configs["e_layers"] - 1)
-            ]
-            if configs.get("distil", False)
-            else None,
-            norm_layer=nn.LayerNorm(configs["d_model"]),
+            [EncoderLayer(
+                # Attn
+                AttentionLayer(Attn(False, configs["factor"], configs["dropout"], False),
+                               configs["d_model"], configs["n_heads"], False),
+                configs["d_model"], configs["d_ff"], configs["dropout"], configs["activation"])
+                for _ in range(configs["e_layers"])],
+            # Conv
+            [ConvLayer(configs["d_model"]) for _ in range(configs["e_layers"] - 1)]
+            # Norm
+            if configs.get("distil") else None, norm_layer=nn.LayerNorm(configs["d_model"])
         )
 
         # Forecast Head
@@ -70,9 +42,6 @@ class InformerEncoder(nn.Module):
         )
 
     def forward(self, x_enc):
-        """
-        x_enc: [B, L, C]
-        """
         x = self.embedding(x_enc)
         x, _ = self.encoder(x)
         out = self.head(x)
diff --git a/train.py b/train.py
index 97f9bc8..55662ba 100644
--- a/train.py
+++ b/train.py
@@ -12,9 +12,9 @@ def read_config(config_path):
         config = yaml.safe_load(file)
 
     # Global configuration
-    device = "cpu"      # set the device, e.g. cuda:0
+    device = "cuda:0"   # set the device, e.g. cuda:0
     seed = 2023         # random seed
-    epochs = 1          # number of training epochs
+    epochs = 100        # number of training epochs
 
     # Copy into the config
     config["basic"]["device"] = device
@@ -102,11 +102,11 @@ def main(model_list, data, debug=False):
 if __name__ == "__main__":
     # For debugging
     # model_list = ["iTransformer", "PatchTST", "HI"]
-    model_list = ["iTransformer", "Informer"]
+    model_list = ["Informer"]
     # model_list = ["PatchTST"]
     # dataset_list = ["AirQuality"]
-    dataset_list = ["SolarEnergy"]
+    # dataset_list = ["SolarEnergy"]
     # dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]
-    # dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
+    dataset_list = ["SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
     # dataset_list = ["BJTaxi-OutFlow"]
-    main(model_list, dataset_list, debug=True)
+    main(model_list, dataset_list, debug=False)
diff --git a/trainer/TSTrainer.py b/trainer/TSTrainer.py
index 11cd431..d8fdc91 100755
--- a/trainer/TSTrainer.py
+++ b/trainer/TSTrainer.py
@@ -76,10 +76,14 @@ class Trainer:
             }
 
             if losses["train"] > 1e6: self.logger.warning("Gradient explosion detected"); break
-            if losses["val"] < best: best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict())
+            if losses["val"] < best:
+                best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict())
+                self.logger.info(f"Best model updated at Epoch {epoch:02d}#")
             else: patience += 1
             if self.args["early_stop"] and patience == self.args["early_stop_patience"]: break
-            if losses["test"] < best_test: best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict())
+            if losses["test"] < best_test:
+                best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict())
+                self.logger.info(f"Best test model saved at Epoch {epoch:02d}#")
 
         if not self.args["debug"]:
             torch.save(best_w, self.best_path)
diff --git a/trainer/Trainer.py b/trainer/Trainer.py
index 7036e26..c46d38a 100755
--- a/trainer/Trainer.py
+++ b/trainer/Trainer.py
@@ -71,10 +71,14 @@ class Trainer:
             }
 
            if losses["train"] > 1e6: self.logger.warning("Gradient explosion detected"); break
-            if losses["val"] < best: best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict())
+            if losses["val"] < best:
+                best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict())
+                self.logger.info(f"Best model updated at Epoch {epoch:02d}#")
             else: patience += 1
             if self.args["early_stop"] and patience == self.args["early_stop_patience"]: break
-            if losses["test"] < best_test: best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict())
+            if losses["test"] < best_test:
+                best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict())
+                self.logger.info(f"Best test model saved at Epoch {epoch:02d}#")
 
         if not self.args["debug"]:
             torch.save(best_w, self.best_path)