Improve trainer best-model update logging, remove redundant Informer code
parent ce6959a99d
commit dfc76b8e90
@@ -1,4 +1,3 @@
-
 import torch
 import torch.nn as nn
 
@@ -9,12 +8,6 @@ from model.Informer.head import TemporalProjectionHead
 
 
 class InformerEncoder(nn.Module):
-    """
-    Informer Encoder-only
-    - Only uses x
-    - No normalization
-    - Multi-channel friendly
-    """
     def __init__(self, configs):
         super().__init__()
 
@@ -25,41 +18,20 @@ class InformerEncoder(nn.Module):
         Attn = ProbAttention if configs["attn"] == "prob" else FullAttention
 
         # Embedding
-        self.embedding = DataEmbedding(
-            configs["enc_in"],
-            configs["d_model"],
-            configs["dropout"],
-        )
+        self.embedding = DataEmbedding(configs["enc_in"], configs["d_model"], configs["dropout"])
 
-        # Encoder (Informer)
+        # Encoder(Attn-Conv-Norm)
         self.encoder = Encoder(
-            [
-                EncoderLayer(
-                    AttentionLayer(
-                        Attn(
-                            False,
-                            configs["factor"],
-                            attention_dropout=configs["dropout"],
-                            output_attention=False,
-                        ),
-                        configs["d_model"],
-                        configs["n_heads"],
-                        mix=False,
-                    ),
-                    configs["d_model"],
-                    configs["d_ff"],
-                    dropout=configs["dropout"],
-                    activation=configs["activation"],
-                )
-                for _ in range(configs["e_layers"])
-            ],
-            [
-                ConvLayer(configs["d_model"])
-                for _ in range(configs["e_layers"] - 1)
-            ]
-            if configs.get("distil", False)
-            else None,
-            norm_layer=nn.LayerNorm(configs["d_model"]),
+            [EncoderLayer(
+                # Attn
+                AttentionLayer(Attn(False, configs["factor"], configs["dropout"], False),
+                               configs["d_model"], configs["n_heads"], False),
+                configs["d_model"], configs["d_ff"], configs["dropout"], configs["activation"])
+                for _ in range(configs["e_layers"])],
+            # Conv
+            [ConvLayer(configs["d_model"]) for _ in range(configs["e_layers"] - 1)]
+            # Norm
+            if configs.get("distil") else None, norm_layer=nn.LayerNorm(configs["d_model"])
         )
 
         # Forecast Head
@@ -70,9 +42,6 @@ class InformerEncoder(nn.Module):
         )
 
     def forward(self, x_enc):
-        """
-        x_enc: [B, L, C]
-        """
        x = self.embedding(x_enc)
        x, _ = self.encoder(x)
        out = self.head(x)
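
For reference, a minimal usage sketch of the refactored InformerEncoder. The config keys are exactly the ones read in __init__ above; the concrete values, the batch/sequence sizes, and the output shape are assumptions of this sketch, and the Forecast Head block elided from this diff will require additional keys not shown here.

import torch
# assuming InformerEncoder is importable from the (unnamed) module in this diff

configs = {
    "enc_in": 7,          # number of input channels C (assumed value)
    "d_model": 512,
    "n_heads": 8,
    "e_layers": 3,
    "d_ff": 2048,
    "factor": 5,
    "dropout": 0.1,
    "attn": "prob",       # "prob" selects ProbAttention, anything else FullAttention
    "activation": "gelu",
    "distil": True,       # enables the ConvLayer stack between encoder layers
    # ...plus whatever keys the elided Forecast Head construction requires
}

model = InformerEncoder(configs)
x_enc = torch.randn(32, 96, 7)  # [B, L, C], as the removed docstring documented
out = model(x_enc)              # output shape depends on the head configuration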
train.py
@@ -12,9 +12,9 @@ def read_config(config_path):
         config = yaml.safe_load(file)
 
     # Global config
-    device = "cpu"  # use device cuda:0
+    device = "cuda:0"  # use device cuda:0
     seed = 2023  # random seed
-    epochs = 1  # number of training epochs
+    epochs = 100  # number of training epochs
 
     # Copied items
     config["basic"]["device"] = device
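
Pieced together from the hunk above, read_config loads the YAML file and then overwrites a few module-level settings into it. A minimal sketch of that pattern; the open() call, the return statement, and the seed/epochs copies are assumptions, since the diff only confirms the device copy:

import yaml

def read_config(config_path):
    with open(config_path, "r") as file:
        config = yaml.safe_load(file)

    # Global config
    device = "cuda:0"  # use device cuda:0
    seed = 2023        # random seed
    epochs = 100       # number of training epochs

    # Copied items (only the "device" copy appears in the diff;
    # copying seed and epochs the same way is an assumption)
    config["basic"]["device"] = device
    config["basic"]["seed"] = seed
    config["basic"]["epochs"] = epochs
    return config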
@@ -102,11 +102,11 @@ def main(model_list, data, debug=False):
 if __name__ == "__main__":
     # For debugging
     # model_list = ["iTransformer", "PatchTST", "HI"]
-    model_list = ["iTransformer", "Informer"]
+    model_list = ["Informer"]
     # model_list = ["PatchTST"]
     # dataset_list = ["AirQuality"]
-    dataset_list = ["SolarEnergy"]
+    # dataset_list = ["SolarEnergy"]
     # dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]
-    # dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
+    dataset_list = ["SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
     # dataset_list = ["BJTaxi-OutFlow"]
-    main(model_list, dataset_list, debug=True)
+    main(model_list, dataset_list, debug=False)
@@ -76,10 +76,14 @@ class Trainer:
             }
 
             if losses["train"] > 1e6: self.logger.warning("Gradient explosion detected"); break
-            if losses["val"] < best: best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict())
+            if losses["val"] < best:
+                best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict())
+                self.logger.info(f"Best model updated at Epoch {epoch:02d}#")
             else: patience += 1
             if self.args["early_stop"] and patience == self.args["early_stop_patience"]: break
-            if losses["test"] < best_test: best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict())
+            if losses["test"] < best_test:
+                best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict())
+                self.logger.info(f"Best test model saved at Epoch {epoch:02d}#")
 
         if not self.args["debug"]:
             torch.save(best_w, self.best_path)
@@ -71,10 +71,14 @@ class Trainer:
             }
 
             if losses["train"] > 1e6: self.logger.warning("Gradient explosion detected"); break
-            if losses["val"] < best: best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict())
+            if losses["val"] < best:
+                best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict())
+                self.logger.info(f"Best model updated at Epoch {epoch:02d}#")
             else: patience += 1
             if self.args["early_stop"] and patience == self.args["early_stop_patience"]: break
-            if losses["test"] < best_test: best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict())
+            if losses["test"] < best_test:
+                best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict())
+                self.logger.info(f"Best test model saved at Epoch {epoch:02d}#")
 
         if not self.args["debug"]:
             torch.save(best_w, self.best_path)
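
Both Trainer hunks above reshape the same best-model tracking from one-liners into blocks so a log message can be emitted on each update. For context, a self-contained sketch of the loop these lines sit in; the evaluate() callback returning a dict of train/val/test losses, the args dict, and the surrounding function structure are assumptions of this sketch:

import copy
import logging

logger = logging.getLogger("Trainer")

def fit(model, evaluate, args, epochs):
    """Sketch of the epoch loop the hunks above belong to."""
    best, best_test, patience = float("inf"), float("inf"), 0
    best_w = best_test_w = copy.deepcopy(model.state_dict())
    for epoch in range(1, epochs + 1):
        losses = evaluate(model)  # assumed: {"train": ..., "val": ..., "test": ...}
        if losses["train"] > 1e6:
            logger.warning("Gradient explosion detected")
            break
        # track the best validation model and reset the early-stop counter
        if losses["val"] < best:
            best, patience, best_w = losses["val"], 0, copy.deepcopy(model.state_dict())
            logger.info(f"Best model updated at Epoch {epoch:02d}#")
        else:
            patience += 1
        if args["early_stop"] and patience == args["early_stop_patience"]:
            break
        # separately track the best test model
        if losses["test"] < best_test:
            best_test, best_test_w = losses["test"], copy.deepcopy(model.state_dict())
            logger.info(f"Best test model saved at Epoch {epoch:02d}#")
    return best_w, best_test_w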