Optimize trainer best-model update messages and remove redundant Informer code

This commit is contained in:
czzhangheng 2025-12-22 17:25:52 +08:00
parent ce6959a99d
commit dfc76b8e90
4 changed files with 30 additions and 53 deletions

View File

@@ -1,4 +1,3 @@
 import torch
 import torch.nn as nn
@@ -9,12 +8,6 @@ from model.Informer.head import TemporalProjectionHead
 class InformerEncoder(nn.Module):
-    """
-    Informer Encoder-only
-    - Only uses x
-    - No normalization
-    - Multi-channel friendly
-    """
     def __init__(self, configs):
         super().__init__()
@@ -25,41 +18,20 @@ class InformerEncoder(nn.Module):
         Attn = ProbAttention if configs["attn"] == "prob" else FullAttention
         # Embedding
-        self.embedding = DataEmbedding(
-            configs["enc_in"],
-            configs["d_model"],
-            configs["dropout"],
-        )
-        # Encoder (Informer)
+        self.embedding = DataEmbedding(configs["enc_in"], configs["d_model"], configs["dropout"])
+        # Encoder(Attn-Conv-Norm)
         self.encoder = Encoder(
-            [
-                EncoderLayer(
-                    AttentionLayer(
-                        Attn(
-                            False,
-                            configs["factor"],
-                            attention_dropout=configs["dropout"],
-                            output_attention=False,
-                        ),
-                        configs["d_model"],
-                        configs["n_heads"],
-                        mix=False,
-                    ),
-                    configs["d_model"],
-                    configs["d_ff"],
-                    dropout=configs["dropout"],
-                    activation=configs["activation"],
-                )
-                for _ in range(configs["e_layers"])
-            ],
-            [
-                ConvLayer(configs["d_model"])
-                for _ in range(configs["e_layers"] - 1)
-            ]
-            if configs.get("distil", False)
-            else None,
-            norm_layer=nn.LayerNorm(configs["d_model"]),
+            # Attn
+            [EncoderLayer(
+                AttentionLayer(Attn(False, configs["factor"], configs["dropout"], False),
+                               configs["d_model"], configs["n_heads"], False),
+                configs["d_model"], configs["d_ff"], configs["dropout"], configs["activation"])
+                for _ in range(configs["e_layers"])],
+            # Conv
+            [ConvLayer(configs["d_model"]) for _ in range(configs["e_layers"] - 1)]
+            if configs.get("distil") else None,
+            # Norm
+            norm_layer=nn.LayerNorm(configs["d_model"])
         )
         # Forecast Head
@@ -70,9 +42,6 @@ class InformerEncoder(nn.Module):
         )
     def forward(self, x_enc):
-        """
-        x_enc: [B, L, C]
-        """
         x = self.embedding(x_enc)
         x, _ = self.encoder(x)
         out = self.head(x)
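
Note: the refactor above only compacts the constructor; the wiring (ProbAttention/FullAttention layers, optional ConvLayer distilling, final LayerNorm) is unchanged. A minimal usage sketch, assuming InformerEncoder is importable and using only the config keys visible in this diff; the forecast head hidden in the hunk may require additional keys (e.g. seq_len or pred_len), so the values below are illustrative:

import torch

configs = {
    "enc_in": 137,      # number of input channels C (example value)
    "d_model": 512,
    "n_heads": 8,
    "e_layers": 2,
    "d_ff": 2048,
    "dropout": 0.1,
    "attn": "prob",     # "prob" -> ProbAttention, otherwise FullAttention
    "factor": 5,
    "activation": "gelu",
    "distil": True,     # enables the ConvLayer distilling stack
}
model = InformerEncoder(configs)
x_enc = torch.randn(32, 96, configs["enc_in"])  # [B, L, C], per the removed docstring
out = model(x_enc)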

View File

@@ -12,9 +12,9 @@ def read_config(config_path):
         config = yaml.safe_load(file)
     # Global configuration
-    device = "cpu"  # set device to cuda:0
+    device = "cuda:0"  # set device to cuda:0
     seed = 2023  # random seed
-    epochs = 1  # number of training epochs
+    epochs = 100  # number of training epochs
     # Copied fields
     config["basic"]["device"] = device
@@ -102,11 +102,11 @@ def main(model_list, data, debug=False):
 if __name__ == "__main__":
     # For debugging
     # model_list = ["iTransformer", "PatchTST", "HI"]
-    model_list = ["iTransformer", "Informer"]
+    model_list = ["Informer"]
     # model_list = ["PatchTST"]
     # dataset_list = ["AirQuality"]
-    dataset_list = ["SolarEnergy"]
+    # dataset_list = ["SolarEnergy"]
     # dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]
-    # dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
+    dataset_list = ["SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
     # dataset_list = ["BJTaxi-OutFlow"]
-    main(model_list, dataset_list, debug=True)
+    main(model_list, dataset_list, debug=False)
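
For reference, the override pattern this hunk edits, assembled into a small runnable sketch. Only the device copy appears in the hunk; the seed/epochs copies and the rest of read_config are inferred and marked as assumptions:

import yaml

def read_config(config_path):
    with open(config_path, "r") as file:
        config = yaml.safe_load(file)
    # Global configuration overrides applied on top of the YAML file
    device = "cuda:0"  # set device to cuda:0
    seed = 2023        # random seed
    epochs = 100       # number of training epochs
    # Copied fields (only the "device" copy is visible in the hunk; the rest is assumed)
    config["basic"]["device"] = device
    config["basic"]["seed"] = seed
    config["basic"]["epochs"] = epochs
    return config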

View File

@@ -76,10 +76,14 @@ class Trainer:
             }
             if losses["train"] > 1e6: self.logger.warning("Gradient explosion detected"); break
-            if losses["val"] < best: best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict())
+            if losses["val"] < best:
+                best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict())
+                self.logger.info(f"Best model updated at Epoch {epoch:02d}#")
             else: patience += 1
             if self.args["early_stop"] and patience == self.args["early_stop_patience"]: break
-            if losses["test"] < best_test: best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict())
+            if losses["test"] < best_test:
+                best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict())
+                self.logger.info(f"Best test model saved at Epoch {epoch:02d}#")
         if not self.args["debug"]:
             torch.save(best_w, self.best_path)
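
The two added logger.info calls announce when the best-validation and best-test weights are refreshed. A self-contained sketch of the tracking and early-stopping pattern those lines sit in, assuming one losses dict per epoch; the real Trainer's model, data loaders, and attributes are not shown in this diff:

import copy
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("trainer-sketch")

def track_best(epoch_losses, early_stop=True, early_stop_patience=15):
    """epoch_losses: iterable of dicts like {"train": .., "val": .., "test": ..}."""
    best, best_test, patience = float("inf"), float("inf"), 0
    best_w = best_test_w = None
    model_state = {"weights": 0.0}  # stand-in for self.model.state_dict()
    for epoch, losses in enumerate(epoch_losses, start=1):
        if losses["train"] > 1e6:
            logger.warning("Gradient explosion detected")
            break
        if losses["val"] < best:
            best, patience, best_w = losses["val"], 0, copy.deepcopy(model_state)
            logger.info(f"Best model updated at Epoch {epoch:02d}#")
        else:
            patience += 1
        if early_stop and patience == early_stop_patience:
            break
        if losses["test"] < best_test:
            best_test, best_test_w = losses["test"], copy.deepcopy(model_state)
            logger.info(f"Best test model saved at Epoch {epoch:02d}#")
    return best_w, best_test_w

# Example: the "Best model updated" message fires at epochs 01 and 03 here.
track_best([{"train": 1.0, "val": 0.90, "test": 0.8},
            {"train": 0.9, "val": 0.95, "test": 0.7},
            {"train": 0.8, "val": 0.85, "test": 0.6}])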

View File

@@ -71,10 +71,14 @@ class Trainer:
             }
             if losses["train"] > 1e6: self.logger.warning("Gradient explosion detected"); break
-            if losses["val"] < best: best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict())
+            if losses["val"] < best:
+                best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict())
+                self.logger.info(f"Best model updated at Epoch {epoch:02d}#")
             else: patience += 1
             if self.args["early_stop"] and patience == self.args["early_stop_patience"]: break
-            if losses["test"] < best_test: best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict())
+            if losses["test"] < best_test:
+                best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict())
+                self.logger.info(f"Best test model saved at Epoch {epoch:02d}#")
         if not self.args["debug"]:
             torch.save(best_w, self.best_path)