feat: restructure the trainer and config layout; streamline model training

refactor(trainer): restructure the Trainer class and split out its initialization methods
perf(trainer): streamline the training loop and evaluation logic
style(config): unify config file naming and structure
fix(trainer): fix the shape-check logic and debug-mode handling
docs: update README and inline comments
parent 3095b7435b
commit 56b09ea8ac
@@ -2,7 +2,7 @@ basic:
   dataset: METR-LA
   device: cuda:0
   mode: train
-  model: AEPSA_v2
+  model: ASTRA_v2
   seed: 2023
 
 data:
@@ -2,7 +2,7 @@ basic:
   dataset: PEMS-BAY
   device: cuda:0
   mode: train
-  model: AEPSA_v3
+  model: ASTRA_v2
   seed: 2023
 
 data:
@@ -2,7 +2,7 @@ basic:
   dataset: SolarEnergy
   device: cuda:0
   mode: train
-  model: AEPSA_v2
+  model: ASTRA_v2
   seed: 2023
 
 data:
@@ -0,0 +1,54 @@
+basic:
+  dataset: AirQuality
+  device: cuda:0
+  mode: train
+  model: ASTRA_v3
+  seed: 2023
+
+data:
+  batch_size: 16
+  column_wise: false
+  days_per_week: 7
+  horizon: 24
+  input_dim: 6
+  lag: 24
+  normalizer: std
+  num_nodes: 35
+  steps_per_day: 24
+  test_ratio: 0.2
+  val_ratio: 0.2
+
+model:
+  d_ff: 128
+  d_model: 64
+  dropout: 0.2
+  gpt_layers: 9
+  gpt_path: ./GPT-2
+  input_dim: 6
+  n_heads: 1
+  num_nodes: 35
+  patch_len: 6
+  pred_len: 24
+  seq_len: 24
+  stride: 7
+  word_num: 1000
+
+train:
+  batch_size: 16
+  debug: false
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  log_step: 100
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: 5,20,40,70
+  lr_init: 0.003
+  mae_thresh: None
+  mape_thresh: 0.001
+  max_grad_norm: 5
+  output_dim: 6
+  plot: false
+  weight_decay: 0
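
Review note: the new configs share a uniform basic / data / model / train layout. A minimal sketch of loading one (assuming PyYAML; the repo's actual read_config() in train.py may merge extra defaults on top of this). Also worth noting: YAML parses mae_thresh: None as the string "None" rather than a null, so downstream code presumably handles that.

# Sketch only: assumes PyYAML and the config layout shown above; the repo's
# real read_config() may apply overrides before returning.
import yaml

def load_config(path: str) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # Every config file above carries these four top-level sections.
    assert {"basic", "data", "model", "train"} <= cfg.keys()
    return cfg

cfg = load_config("./config/ASTRA_v3/AirQuality.yaml")  # hypothetical path
print(cfg["basic"]["model"], cfg["train"]["lr_init"])   # ASTRA_v3 0.003
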
@@ -0,0 +1,54 @@
+basic:
+  dataset: BJTaxi-InFlow
+  device: cuda:0
+  mode: train
+  model: ASTRA_v3
+  seed: 2023
+
+data:
+  batch_size: 32
+  column_wise: false
+  days_per_week: 7
+  horizon: 24
+  input_dim: 1
+  lag: 24
+  normalizer: std
+  num_nodes: 1024
+  steps_per_day: 48
+  test_ratio: 0.2
+  val_ratio: 0.2
+
+model:
+  d_ff: 128
+  d_model: 64
+  dropout: 0.2
+  gpt_layers: 9
+  gpt_path: ./GPT-2
+  input_dim: 1
+  n_heads: 1
+  num_nodes: 1024
+  patch_len: 6
+  pred_len: 24
+  seq_len: 24
+  stride: 7
+  word_num: 1000
+
+train:
+  batch_size: 32
+  debug: false
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  log_step: 100
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: 5,20,40,70
+  lr_init: 0.003
+  mae_thresh: None
+  mape_thresh: 0.001
+  max_grad_norm: 5
+  output_dim: 1
+  plot: false
+  weight_decay: 0
@@ -0,0 +1,54 @@
+basic:
+  dataset: BJTaxi-OutFlow
+  device: cuda:0
+  mode: train
+  model: ASTRA_v3
+  seed: 2023
+
+data:
+  batch_size: 32
+  column_wise: false
+  days_per_week: 7
+  horizon: 24
+  input_dim: 1
+  lag: 24
+  normalizer: std
+  num_nodes: 1024
+  steps_per_day: 48
+  test_ratio: 0.2
+  val_ratio: 0.2
+
+model:
+  d_ff: 128
+  d_model: 64
+  dropout: 0.2
+  gpt_layers: 9
+  gpt_path: ./GPT-2
+  input_dim: 1
+  n_heads: 1
+  num_nodes: 1024
+  patch_len: 6
+  pred_len: 24
+  seq_len: 24
+  stride: 7
+  word_num: 1000
+
+train:
+  batch_size: 32
+  debug: false
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  log_step: 100
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: 5,20,40,70
+  lr_init: 0.003
+  mae_thresh: None
+  mape_thresh: 0.001
+  max_grad_norm: 5
+  output_dim: 1
+  plot: false
+  weight_decay: 0
@@ -2,7 +2,7 @@ basic:
   dataset: METR-LA
   device: cuda:0
   mode: train
-  model: AEPSA_v3
+  model: ASTRA_v3
   seed: 2023
 
 data:
@@ -19,11 +19,9 @@ data:
   val_ratio: 0.2
 
 model:
-  chebyshev_order: 3
   d_ff: 128
   d_model: 64
   dropout: 0.2
-  graph_hidden_dim: 32
   gpt_layers: 9
   gpt_path: ./GPT-2
   input_dim: 1
@@ -0,0 +1,54 @@
+basic:
+  dataset: NYCBike-InFlow
+  device: cuda:0
+  mode: train
+  model: ASTRA_v3
+  seed: 2023
+
+data:
+  batch_size: 32
+  column_wise: false
+  days_per_week: 7
+  horizon: 24
+  input_dim: 1
+  lag: 24
+  normalizer: std
+  num_nodes: 128
+  steps_per_day: 48
+  test_ratio: 0.2
+  val_ratio: 0.2
+
+model:
+  d_ff: 128
+  d_model: 64
+  dropout: 0.2
+  gpt_layers: 9
+  gpt_path: ./GPT-2
+  input_dim: 1
+  n_heads: 1
+  num_nodes: 128
+  patch_len: 6
+  pred_len: 24
+  seq_len: 24
+  stride: 7
+  word_num: 1000
+
+train:
+  batch_size: 32
+  debug: false
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  log_step: 100
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: 5,20,40,70
+  lr_init: 0.003
+  mae_thresh: None
+  mape_thresh: 0.001
+  max_grad_norm: 5
+  output_dim: 1
+  plot: false
+  weight_decay: 0
@@ -0,0 +1,54 @@
+basic:
+  dataset: NYCBike-OutFlow
+  device: cuda:0
+  mode: train
+  model: ASTRA_v3
+  seed: 2023
+
+data:
+  batch_size: 32
+  column_wise: false
+  days_per_week: 7
+  horizon: 24
+  input_dim: 1
+  lag: 24
+  normalizer: std
+  num_nodes: 128
+  steps_per_day: 48
+  test_ratio: 0.2
+  val_ratio: 0.2
+
+model:
+  d_ff: 128
+  d_model: 64
+  dropout: 0.2
+  gpt_layers: 9
+  gpt_path: ./GPT-2
+  input_dim: 1
+  n_heads: 1
+  num_nodes: 128
+  patch_len: 6
+  pred_len: 24
+  seq_len: 24
+  stride: 7
+  word_num: 1000
+
+train:
+  batch_size: 32
+  debug: false
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  log_step: 100
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: 5,20,40,70
+  lr_init: 0.003
+  mae_thresh: None
+  mape_thresh: 0.001
+  max_grad_norm: 5
+  output_dim: 1
+  plot: false
+  weight_decay: 0
@@ -2,7 +2,7 @@ basic:
   dataset: PEMS-BAY
   device: cuda:0
   mode: train
-  model: AEPSA_v2
+  model: ASTRA_v3
   seed: 2023
 
 data:
@@ -19,11 +19,9 @@ data:
   val_ratio: 0.2
 
 model:
-  chebyshev_order: 3
   d_ff: 128
   d_model: 64
   dropout: 0.2
-  graph_hidden_dim: 32
   gpt_layers: 9
   gpt_path: ./GPT-2
   input_dim: 1
@@ -0,0 +1,54 @@
+basic:
+  dataset: SolarEnergy
+  device: cuda:0
+  mode: train
+  model: ASTRA_v3
+  seed: 2023
+
+data:
+  batch_size: 64
+  column_wise: false
+  days_per_week: 7
+  horizon: 24
+  input_dim: 1
+  lag: 24
+  normalizer: std
+  num_nodes: 137
+  steps_per_day: 24
+  test_ratio: 0.2
+  val_ratio: 0.2
+
+model:
+  d_ff: 128
+  d_model: 64
+  dropout: 0.2
+  gpt_layers: 9
+  gpt_path: ./GPT-2
+  input_dim: 1
+  n_heads: 1
+  num_nodes: 137
+  patch_len: 6
+  pred_len: 24
+  seq_len: 24
+  stride: 7
+  word_num: 1000
+
+train:
+  batch_size: 64
+  debug: false
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  log_step: 100
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: 5,20,40,70
+  lr_init: 0.003
+  mae_thresh: None
+  mape_thresh: 0.001
+  max_grad_norm: 5
+  output_dim: 1
+  plot: false
+  weight_decay: 0
@@ -13,8 +13,8 @@ data:
   input_dim: 6
   lag: 24
   normalizer: std
-  num_nodes: 12
-  steps_per_day: 288
+  num_nodes: 35
+  steps_per_day: 24
   test_ratio: 0.2
   val_ratio: 0.2
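
Review note: num_nodes and input_dim are duplicated between the data and model sections, so drift like the 12 vs. 35 fix above is easy to reintroduce. A small guard (hypothetical helper, not in the repo) could fail fast at load time:

# Hypothetical sanity check, not part of the repo: reject configs whose
# duplicated data/model keys have drifted apart.
def check_config(cfg: dict) -> None:
    for key in ("num_nodes", "input_dim"):
        if cfg["data"][key] != cfg["model"][key]:
            raise ValueError(
                f"config mismatch: data.{key}={cfg['data'][key]} "
                f"!= model.{key}={cfg['model'][key]}"
            )
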
@@ -184,7 +184,7 @@ class ASTRA(nn.Module):
 
     def forward(self, x):
         # data preprocessing
-        x = x[..., :1]  # [B,T,N,1]
+        x = x[..., :self.input_dim]
         x_enc = rearrange(x, 'b t n c -> b n c t')  # [B,N,1,T]
 
         # graph encoding
@@ -203,7 +203,9 @@ class ASTRA(nn.Module):
-        dec_out = self.out_mlp(X_enc)  # [B,N,pred_len]
+        dec_out = self.out_mlp(enc_out)
 
         # dimension adjustment
-        outputs = dec_out.unsqueeze(dim=-1)  # [B,N,pred_len,1]
-        outputs = outputs.permute(0, 2, 1, 3)  # [B,pred_len,N,1]
+        outputs = dec_out.unsqueeze(dim=-1)
+        outputs = outputs.repeat(1, 1, 1, self.input_dim)
+        outputs = outputs.permute(0,2,1,3)
 
         return outputs
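
Review note: the new output path broadcasts the single-channel prediction back to input_dim channels before reordering to [B, pred_len, N, C]. A dummy-tensor walkthrough of the unsqueeze/repeat/permute chain (shapes assumed from the AirQuality config: N=35, pred_len=24, input_dim=6):

# Dummy-shape walkthrough of the new output path; values are arbitrary.
import torch

B, N, pred_len, input_dim = 2, 35, 24, 6
dec_out = torch.randn(B, N, pred_len)          # what self.out_mlp returns

outputs = dec_out.unsqueeze(dim=-1)            # [B, N, pred_len, 1]
outputs = outputs.repeat(1, 1, 1, input_dim)   # [B, N, pred_len, input_dim]
outputs = outputs.permute(0, 2, 1, 3)          # [B, pred_len, N, input_dim]
assert outputs.shape == (B, pred_len, N, input_dim)
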
train.py (23 lines changed)
@@ -60,25 +60,17 @@ def run(config):
         case _:
             raise ValueError(f"Unsupported mode: {config['basic']['mode']}")
 
-def main(debug=False):
-    # pick the models
-    model_list = ["iTransformer"]
-    # pick the datasets
-    # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"]
-    # dataset_list = ["AirQuality"]
-    dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"]
-
-    # my debug switch; set str(False) when not testing
-    # os.environ["TRY"] = str(False)
-    os.environ["TRY"] = str(False)
+def main(model, data, debug=False):
+    os.environ["TRY"] = str(debug)
 
-    for model in model_list:
-        for dataset in dataset_list:
+    for dataset in data:
         config_path = f"./config/{model}/{dataset}.yaml"
         # shared config entries can be adjusted inside this function; ⚠️ mind the device and epochs settings
         config = read_config(config_path)
         print(f"\nRunning {model} on {dataset}")
         # print(f"config: {config}")
         if os.environ.get("TRY") == "True":
             try:
                 run(config)
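
Review note: the debug flag now flows from main() into the TRY environment variable instead of a hand-edited constant. Since os.environ only stores strings, str(debug) round-trips through the literal "True" comparison; a minimal sketch of the gating pattern:

# Minimal sketch of the TRY gating used above: env vars are strings,
# so str(True) == "True" is what the comparison relies on.
import os

def run_once(debug: bool) -> None:
    os.environ["TRY"] = str(debug)
    if os.environ.get("TRY") == "True":
        print("debug run: wrap in try/except and stop early")
    else:
        print("normal run")

run_once(debug=True)   # debug run
run_once(debug=False)  # normal run
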
@@ -97,4 +89,9 @@ def main(debug=False):
 
 if __name__ == "__main__":
     # for debugging
-    main(debug = True)
+    # model_list = ["iTransformer", "PatchTST", "HI"]
+    model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"]
+    # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"]
+    # dataset_list = ["AirQuality"]
+    dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"]
+    main(model_list, dataset_list, debug = True)
@@ -73,8 +73,12 @@ class Trainer:
         out = self.ts.inverse(out, b, t, n, c)
 
         if os.environ.get("TRY") == "True":
-            print(out.shape, label.shape)
-            assert out.shape == label.shape
+            if out.shape == label.shape:
+                print("shape true")
+                assert False
+            else:
+                print("shape false")
+                assert False
 
         loss = self.loss(out, label)
         d_out = self.scaler.inverse_transform(out)
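
Review note: both branches of the new shape check end in assert False, which reads as deliberate: in debug mode the run halts after printing the shapes either way. If only a genuine mismatch should abort, a plain assertion would do (a sketch, not the repo's code):

# Alternative sketch: abort only on a real mismatch instead of always.
import os
import torch

def debug_shape_check(out: torch.Tensor, label: torch.Tensor) -> None:
    if os.environ.get("TRY") == "True":
        assert out.shape == label.shape, (
            f"shape mismatch: out {tuple(out.shape)} vs label {tuple(label.shape)}"
        )
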
@@ -1,106 +1,180 @@
-import os
-import time
-import copy
-import torch
+import os, time, copy, torch
+from tqdm import tqdm
 from utils.logger import get_logger
 from utils.loss_function import all_metrics
-from tqdm import tqdm
 
 class Trainer:
-    def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None):
-        self.config, self.device, self.args = args, args["basic"]["device"], args["train"]
-        self.model, self.loss, self.optimizer, self.lr_scheduler = model, loss, optimizer, lr_scheduler
-        self.train_loader, self.val_loader, self.test_loader, self.scaler = train_loader, val_loader, test_loader, scaler
-
-        log_dir = self.args["log_dir"]
-        self.best_path, self.best_test_path = [os.path.join(log_dir, f"best_{suffix}_model.pth") for suffix in ["", "test"]]
-
-        if not self.args["debug"]: os.makedirs(log_dir, exist_ok=True)
-        self.logger = get_logger(log_dir, name=self.model.__class__.__name__, debug=self.args["debug"])
-        self.logger.info(f"Experiment log path in: {log_dir}")
-
-    def train(self):
-        best_model = best_test_model = None
-        best_loss = best_test_loss = float("inf")
-        not_improved_count = 0
-
-        self.logger.info("Training process started")
-
-        for epoch in range(1, self.args["epochs"] + 1):
-            train_loss = self._run_epoch(epoch, self.train_loader, "train")
-            val_loss = self._run_epoch(epoch, self.val_loader or self.test_loader, "val")
-            test_loss = self._run_epoch(epoch, self.test_loader, "test")
-
-            if train_loss > 1e6:
-                self.logger.warning("Gradient explosion detected. Ending...")
-                break
-
-            if val_loss < best_loss:
-                best_loss, not_improved_count, best_model = val_loss, 0, copy.deepcopy(self.model.state_dict())
-                self.logger.info("Best validation model saved!")
-            elif self.args["early_stop"] and (not_improved_count := not_improved_count + 1) == self.args["early_stop_patience"]:
-                self.logger.info(f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
-                break
-
-            if test_loss < best_test_loss:
-                best_test_loss, best_test_model = test_loss, copy.deepcopy(self.model.state_dict())
-
-        torch.save(best_model, self.best_path)
-        torch.save(best_test_model, self.best_test_path)
-        self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
-
-        for model_name, state_dict in [("best validation", best_model), ("best test", best_test_model)]:
-            self.model.load_state_dict(state_dict)
-            self.logger.info(f"Testing on {model_name} model")
-            self._run_epoch(None, self.test_loader, "test", log_horizon=True)
-
-    def _run_epoch(self, epoch, dataloader, mode, log_horizon=False):
-        self.model.train() if mode == "train" else self.model.eval()
-        optimizer_step = mode == "train"
-
-        total_loss, epoch_time = 0, time.time()
-        y_pred, y_true = [], []
-
-        with torch.set_grad_enabled(optimizer_step):
-            for data, target in tqdm(dataloader, total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" if epoch else mode):
-                data, target = data.to(self.device), target.to(self.device)
-                label = target[..., :self.args["output_dim"]]
-
-                output = self.model(data)
-                loss = self.loss(output, label)
-                d_output, d_label = self.scaler.inverse_transform(output), self.scaler.inverse_transform(label)
-                d_loss = self.loss(d_output, d_label)
-
-                total_loss += d_loss.item()
-                y_pred.append(d_output.detach().cpu())
-                y_true.append(d_label.detach().cpu())
-
-                if optimizer_step and self.optimizer:
-                    self.optimizer.zero_grad()
-                    loss.backward()
-                    if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
-                    self.optimizer.step()
-
-        y_pred, y_true = torch.cat(y_pred, dim=0), torch.cat(y_true, dim=0)
-
-        if log_horizon:
-            for t in range(y_true.shape[1]):
-                mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], self.args["mae_thresh"], self.args["mape_thresh"])
-                self.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
-
-        avg_mae, avg_rmse, avg_mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
-
-        if epoch and mode:
-            self.logger.info(f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{avg_mae:5.2f} | RMSE:{avg_rmse:5.2f} | MAPE:{avg_mape:7.4f} | Time: {time.time()-epoch_time:.2f} s")
-        elif mode:
-            self.logger.info(f"{mode.capitalize():<5} MAE:{avg_mae:.4f} | RMSE:{avg_rmse:.4f} | MAPE:{avg_mape:.4f}")
-
-        return total_loss / len(dataloader)
-
-    def test(self, path=None):
-        if path:
-            self.model.load_state_dict(torch.load(path)["state_dict"])
-            self.model.to(self.device)
-        self._run_epoch(None, self.test_loader, "test", log_horizon=True)
+    def __init__(self, model, loss, optimizer,
+                 train_loader, val_loader, test_loader,
+                 scaler, args, lr_scheduler=None):
+
+        self.config = args
+        self.device = args["basic"]["device"]
+        self.args = args["train"]
+
+        self.model = model.to(self.device)
+        self.loss = loss
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+
+        self.train_loader = train_loader
+        self.val_loader = val_loader or test_loader
+        self.test_loader = test_loader
+        self.scaler = scaler
+
+        self._init_paths()
+        self._init_logger()
+
+    # ---------------- init ----------------
+    def _init_paths(self):
+        d = self.args["log_dir"]
+        self.best_path = os.path.join(d, "best_model.pth")
+        self.best_test_path = os.path.join(d, "best_test_model.pth")
+
+    def _init_logger(self):
+        if not self.args["debug"]:
+            os.makedirs(self.args["log_dir"], exist_ok=True)
+        self.logger = get_logger(
+            self.args["log_dir"],
+            name=self.model.__class__.__name__,
+            debug=self.args["debug"],
+        )
+
+    # ---------------- epoch ----------------
+    def _run_epoch(self, epoch, loader, mode):
+        is_train = mode == "train"
+        self.model.train() if is_train else self.model.eval()
+
+        total_loss, start = 0.0, time.time()
+        y_pred, y_true = [], []
+
+        with torch.set_grad_enabled(is_train):
+            for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)):
+                data, target = data.to(self.device), target.to(self.device)
+                label = target[..., :self.args["output_dim"]]
+
+                out = self.model(data)
+
+                if os.environ.get("TRY") == "True":
+                    if out.shape == label.shape:
+                        print(f"shape true, out: {out.shape}, label: {label.shape}")
+                        assert False
+                    else:
+                        print(f"shape false, out: {out.shape}, label: {label.shape}")
+                        assert False
+
+                loss = self.loss(out, label)
+                d_out = self.scaler.inverse_transform(out)
+                d_lbl = self.scaler.inverse_transform(label)
+                d_loss = self.loss(d_out, d_lbl)
+
+                total_loss += d_loss.item()
+                y_pred.append(d_out.detach().cpu())
+                y_true.append(d_lbl.detach().cpu())
+
+                if is_train and self.optimizer:
+                    self.optimizer.zero_grad()
+                    loss.backward()
+                    if self.args["grad_norm"]:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.model.parameters(),
+                            self.args["max_grad_norm"]
+                        )
+                    self.optimizer.step()
+
+        y_pred = torch.cat(y_pred)
+        y_true = torch.cat(y_true)
+
+        mae, rmse, mape = all_metrics(
+            y_pred, y_true,
+            self.args["mae_thresh"],
+            self.args["mape_thresh"]
+        )
+
+        self.logger.info(
+            f"Epoch #{epoch:02d} {mode:<5} "
+            f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} "
+            f"MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s"
+        )
+        return total_loss / len(loader)
+
+    # ---------------- train ----------------
+    def train(self):
+        best, best_test = float("inf"), float("inf")
+        best_w, best_test_w = None, None
+        patience = 0
+
+        self.logger.info("Training started")
+
+        for epoch in range(1, self.args["epochs"] + 1):
+            losses = {
+                "train": self._run_epoch(epoch, self.train_loader, "train"),
+                "val": self._run_epoch(epoch, self.val_loader, "val"),
+                "test": self._run_epoch(epoch, self.test_loader, "test"),
+            }
+
+            if losses["train"] > 1e6:
+                self.logger.warning("Gradient explosion detected")
+                break
+
+            if losses["val"] < best:
+                best, patience = losses["val"], 0
+                best_w = copy.deepcopy(self.model.state_dict())
+                self.logger.info("Best validation model saved")
+            else:
+                patience += 1
+
+            if self.args["early_stop"] and patience == self.args["early_stop_patience"]:
+                self.logger.info("Early stopping triggered")
+                break
+
+            if losses["test"] < best_test:
+                best_test = losses["test"]
+                best_test_w = copy.deepcopy(self.model.state_dict())
+
+        if not self.args["debug"]:
+            torch.save(best_w, self.best_path)
+            torch.save(best_test_w, self.best_test_path)
+
+        self._final_test(best_w, best_test_w)
+
+    # ---------------- final test ----------------
+    def _final_test(self, best_w, best_test_w):
+        for name, w in [("best val", best_w), ("best test", best_test_w)]:
+            self.model.load_state_dict(w)
+            self.logger.info(f"Testing on {name} model")
+            self.evaluate()
+
+    # ---------------- evaluate ----------------
+    def evaluate(self):
+        self.model.eval()
+        y_pred, y_true = [], []
+
+        with torch.no_grad():
+            for data, target in self.test_loader:
+                data, target = data.to(self.device), target.to(self.device)
+                label = target[..., :self.args["output_dim"]]
+
+                out = self.model(data)
+
+                y_pred.append(out.cpu())
+                y_true.append(label.cpu())
+
+        d_pred = self.scaler.inverse_transform(torch.cat(y_pred))
+        d_true = self.scaler.inverse_transform(torch.cat(y_true))
+
+        for t in range(d_true.shape[1]):
+            mae, rmse, mape = all_metrics(
+                d_pred[:, t], d_true[:, t],
+                self.args["mae_thresh"],
+                self.args["mape_thresh"]
+            )
+            self.logger.info(
+                f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}"
+            )
+
+        avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"])
+        self.logger.info(
+            f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}"
+        )
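
Review note: despite the split-out _init_paths/_init_logger helpers, the Trainer is wired up the same way from the caller's side. A hedged usage sketch with stand-in components (the real model, loaders, scaler, and config come from the repo's pipeline; the import path is assumed):

# Usage sketch with stand-in pieces; the real loaders, scaler, and config
# come from the repo's data pipeline and read_config().
import torch
from torch.utils.data import DataLoader, TensorDataset
from trainer import Trainer  # assumed module path; adjust to the repo layout

class IdentityScaler:
    def inverse_transform(self, x):  # stand-in for the repo's scaler
        return x

T, N, C = 24, 35, 6
data = TensorDataset(torch.randn(64, T, N, C), torch.randn(64, T, N, C))
loader = DataLoader(data, batch_size=16)

model = torch.nn.Linear(C, C)  # placeholder model, shape-compatible
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
loss = torch.nn.L1Loss()       # matches loss_func: mae in the configs

args = {"basic": {"device": "cpu"},
        "train": {"log_dir": "./logs", "debug": True, "epochs": 1,
                  "output_dim": C, "grad_norm": False, "max_grad_norm": 5,
                  "early_stop": False, "early_stop_patience": 15,
                  "mae_thresh": None, "mape_thresh": 0.001}}

trainer = Trainer(model, loss, optimizer, loader, loader, loader,
                  IdentityScaler(), args)
trainer.train()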