# TrafficWheel/trainer/InformerTrainer.py


import math
import os
import time
import copy
import torch
from utils.logger import get_logger
from utils.loss_function import all_metrics
from tqdm import tqdm

class InformerTrainer:
    """Trainer for the Informer model: manages the full training workflow and supports multi-input models."""

    def __init__(self, model, loss, optimizer,
                 train_loader, val_loader, test_loader, scaler,
                 args, lr_scheduler=None):
        # Device and basic parameters
        self.config = args
        self.device = args["basic"]["device"]
        train_args = args["train"]
        # Model and training-related components
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        # Data loaders
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        # Data preprocessing utility
        self.scaler = scaler
        self.args = train_args
        # Initialize save paths and logging
        self._initialize_paths(train_args)
        self._initialize_logger(train_args)

    def _initialize_paths(self, args):
        """Initialize model checkpoint and figure paths."""
        self.best_path = os.path.join(args["log_dir"], "best_model.pth")
        self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth")
        self.loss_figure_path = os.path.join(args["log_dir"], "loss.png")

    def _initialize_logger(self, args):
        """Initialize the logger."""
        if not os.path.isdir(args["log_dir"]) and not args["debug"]:
            os.makedirs(args["log_dir"], exist_ok=True)
        self.logger = get_logger(args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"])
        self.logger.info(f"Experiment log path: {args['log_dir']}")

    def _run_epoch(self, epoch, dataloader, mode):
        """Run a single train/val/test epoch. Supports multi-input models."""
        # Set the model mode and whether optimizer steps are taken
        if mode == "train":
            self.model.train()
            optimizer_step = True
        else:
            self.model.eval()
            optimizer_step = False
        # Running statistics
        total_loss = 0
        epoch_time = time.time()
        y_pred, y_true = [], []
        # Train/eval loop; gradients are enabled only in training mode
        with torch.set_grad_enabled(optimizer_step):
            progress_bar = tqdm(
                enumerate(dataloader),
                total=len(dataloader),
                desc=f"{mode.capitalize()} Epoch {epoch}"
            )
            for _, (x, y, x_mark, y_mark) in progress_bar:
                # Move data to the target device
                x = x.to(self.device)
                y = y[:, -self.args['pred_len']:, :self.args["output_dim"]].to(self.device)
                x_mark = x_mark.to(self.device)
                y_mark = y_mark.to(self.device)
                # e.g. [256, 24, 6]
                dec_inp = torch.zeros_like(y[:, -self.args['pred_len']:, :]).float()
                # e.g. [256, 48 (label_len + pred_len), 6]
                dec_inp = torch.cat([y[:, :self.args['label_len'], :], dec_inp], dim=1).float().to(self.device)
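                # Informer-style decoder input: the first label_len steps are
                # known context copied from the target window, followed by
                # pred_len zero placeholders for the decoder to fill in. Note
                # that y was already truncated to the last pred_len steps
                # above, so the context here is drawn from the start of that
                # prediction window.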
                # Forward pass; the loss is computed on normalized data,
                # with a denormalized loss kept for reporting
                output = self.model(x, x_mark, dec_inp, y_mark)
                if os.environ.get("TRY") == "True":
                    # Shape-check mode: report whether output and label shapes match, then abort
                    print(f"[{'✓' if output.shape == y.shape else '✗'}]: output: {output.shape}, label: {y.shape}")
                    assert False
                loss = self.loss(output, y)
                d_output = self.scaler.inverse_transform(output)
                d_label = self.scaler.inverse_transform(y)
                d_loss = self.loss(d_output, d_label)
                # Accumulate the denormalized loss and predictions
                total_loss += d_loss.item()
                y_pred.append(d_output.detach().cpu())
                y_true.append(d_label.detach().cpu())
                # Backpropagation and optimizer step (training mode only)
                if optimizer_step and self.optimizer is not None:
                    self.optimizer.zero_grad()
                    loss.backward()
                    # Gradient clipping (if enabled)
                    if self.args["grad_norm"]:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
                    self.optimizer.step()
                # Update the progress bar
                progress_bar.set_postfix(loss=d_loss.item())
        # Concatenate predictions across all batches
        y_pred = torch.cat(y_pred, dim=0)
        y_true = torch.cat(y_true, dim=0)
        # Compute the average loss and log epoch metrics
        avg_loss = total_loss / len(dataloader)
        mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
        self.logger.info(
            f"Epoch #{epoch:02d}: {mode.capitalize():<5} "
            f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s"
        )
        return avg_loss

    def train_epoch(self, epoch):
        return self._run_epoch(epoch, self.train_loader, "train")

    def val_epoch(self, epoch):
        # Fall back to the test loader when no validation split is available
        return self._run_epoch(epoch, self.val_loader or self.test_loader, "val")

    def test_epoch(self, epoch):
        return self._run_epoch(epoch, self.test_loader, "test")

    def train(self):
        # Track the best models and their losses
        best_model, best_test_model = None, None
        best_loss, best_test_loss = float("inf"), float("inf")
        not_improved_count = 0
        self.logger.info("Training process started")
        # Training loop
        for epoch in range(1, self.args["epochs"] + 1):
            # Run one train, validation, and test epoch
            train_epoch_loss = self.train_epoch(epoch)
            val_epoch_loss = self.val_epoch(epoch)
            test_epoch_loss = self.test_epoch(epoch)
            # Guard against divergence / gradient explosion
            if train_epoch_loss > 1e6:
                self.logger.warning("Gradient explosion detected. Ending...")
                break
            # Update the best validation model
            if val_epoch_loss < best_loss:
                best_loss = val_epoch_loss
                not_improved_count = 0
                best_model = copy.deepcopy(self.model.state_dict())
                self.logger.info("Best validation model saved!")
            else:
                not_improved_count += 1
            # Early stopping
            if self._should_early_stop(not_improved_count):
                break
            # Update the best test model
            if test_epoch_loss < best_test_loss:
                best_test_loss = test_epoch_loss
                best_test_model = copy.deepcopy(self.model.state_dict())
        # Save the best checkpoints
        if not self.args["debug"]:
            self._save_best_models(best_model, best_test_model)
        # Final evaluation
        self._finalize_training(best_model, best_test_model)

    def _should_early_stop(self, not_improved_count):
        """Check whether the early-stopping criterion is met."""
        if (
            self.args["early_stop"]
            and not_improved_count == self.args["early_stop_patience"]
        ):
            self.logger.info(
                f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops."
            )
            return True
        return False

    def _save_best_models(self, best_model, best_test_model):
        """Save the best model state dicts to disk."""
        torch.save(best_model, self.best_path)
        torch.save(best_test_model, self.best_test_path)
        self.logger.info(
            f"Best models saved at {self.best_path} and {self.best_test_path}"
        )

    def _log_model_params(self):
        """Log the number of trainable model parameters."""
        total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.logger.info(f"Trainable params: {total_params}")

    def _finalize_training(self, best_model, best_test_model):
        self.model.load_state_dict(best_model)
        self.logger.info("Testing on best validation model")
        self.test(self.model, self.args, self.test_loader, self.scaler, self.logger)
        self.model.load_state_dict(best_test_model)
        self.logger.info("Testing on best test model")
        self.test(self.model, self.args, self.test_loader, self.scaler, self.logger)

    @staticmethod
    def test(model, args, data_loader, scaler, logger, path=None):
        """Evaluate a model and log its performance metrics. Supports multi-input models."""
        device = args["device"]
        if path:
            checkpoint = torch.load(path)
            model.load_state_dict(checkpoint["state_dict"])
            model.to(device)
        # Switch to evaluation mode
        model.eval()
        # Collect predictions and ground-truth labels
        y_pred, y_true = [], []
        pred_len = args['pred_len']
        label_len = args['label_len']
        output_dim = args['output_dim']
        # Run inference without tracking gradients
        with torch.no_grad():
            for _, (x, y, x_mark, y_mark) in enumerate(data_loader):
                # Move data to the target device
                x = x.to(device)
                y = y[:, -pred_len:, :output_dim].to(device)
                x_mark = x_mark.to(device)
                y_mark = y_mark.to(device)
                # Build the decoder input: known context plus zero placeholders
                dec_inp = torch.zeros_like(y[:, -pred_len:, :]).float()
                dec_inp = torch.cat([y[:, :label_len, :], dec_inp], dim=1).float().to(device)
                output = model(x, x_mark, dec_inp, y_mark)
                y_pred.append(output.detach().cpu())
                y_true.append(y.detach().cpu())
        d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
        d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
        mae_thresh = args["mae_thresh"]
        mape_thresh = args["mape_thresh"]
        # Compute and log metrics for each prediction horizon
        for t in range(d_y_true.shape[1]):
            mae, rmse, mape = all_metrics(
                d_y_pred[:, t, ...],
                d_y_true[:, t, ...],
                mae_thresh,
                mape_thresh,
            )
            logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
        # Compute and log metrics averaged over all horizons
        mae, rmse, mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh)
        logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")

    @staticmethod
    def _compute_sampling_threshold(global_step, k):
        """Inverse sigmoid decay k / (k + exp(step / k)), typically used as a
        scheduled-sampling probability; not called anywhere in this trainer."""
        return k / (k + math.exp(global_step / k))
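

if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustrative addition, not part of the
    # original pipeline): wires the trainer to a dummy model and random
    # tensors so the control flow can be exercised without the real dataset.
    # Every shape, config value, and helper class below is an assumption.
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    class _DummyInformer(nn.Module):
        """Stand-in model exposing the (x, x_mark, dec_inp, y_mark) interface."""

        def __init__(self, dim, pred_len):
            super().__init__()
            self.pred_len = pred_len
            self.proj = nn.Linear(dim, dim)

        def forward(self, x, x_mark, dec_inp, y_mark):
            # Predict only the last pred_len steps of the decoder input
            return self.proj(dec_inp[:, -self.pred_len:, :])

    class _IdentityScaler:
        """Stand-in scaler whose inverse_transform is a no-op."""

        def inverse_transform(self, data):
            return data

    seq_len, label_len, pred_len, dim = 96, 24, 24, 6
    dataset = TensorDataset(
        torch.randn(64, seq_len, dim),               # x
        torch.randn(64, label_len + pred_len, dim),  # y
        torch.randn(64, seq_len, 4),                 # x_mark (4 time features assumed)
        torch.randn(64, label_len + pred_len, 4),    # y_mark
    )
    loader = DataLoader(dataset, batch_size=8)

    args = {
        "basic": {"device": "cpu"},
        "train": {
            "log_dir": "./logs",
            "debug": True,  # skips checkpoint saving
            "epochs": 1,
            "early_stop": False,
            "early_stop_patience": 15,
            "grad_norm": False,
            "max_grad_norm": 5.0,
            "pred_len": pred_len,
            "label_len": label_len,
            "output_dim": dim,
            "mae_thresh": None,   # assumed to be accepted by all_metrics
            "mape_thresh": 0.001,
            "device": "cpu",      # the static test() reads "device" from these args
        },
    }
    model = _DummyInformer(dim, pred_len)
    trainer = InformerTrainer(
        model,
        nn.L1Loss(),
        torch.optim.Adam(model.parameters(), lr=1e-3),
        loader, loader, loader,
        _IdentityScaler(),
        args,
    )
    trainer.train_epoch(1)  # one training pass over the dummy data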