refactor: restructure the data loader and trainer code to improve structure and readability

Refactor the data loader module, replacing the switch-case (match-case) dispatch with a dictionary mapping (see the sketch below)
Simplify the trainer logic and merge duplicated code to improve maintainability
Shorten the log timestamp format
Adjust the training config: reduce the default epoch count and enable GPU training
Unify data loading and extract shared helpers to reduce duplication
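A minimal sketch of the dictionary-dispatch pattern adopted here in place of match-case; the dataset names and arrays below are purely illustrative, not the project's real loaders:

import numpy as np

def load_dataset(name):
    # Each entry is a zero-argument callable, so nothing is loaded until it is selected.
    loaders = {
        "toy_zeros": lambda: np.zeros((10, 4, 1), dtype=np.float32),  # hypothetical dataset
        "toy_ones": lambda: np.ones((10, 4, 1), dtype=np.float32),    # hypothetical dataset
    }
    if name not in loaders:
        raise ValueError(f"Unsupported dataset: {name}")
    return loaders[name]()

print(load_dataset("toy_zeros").shape)  # (10, 4, 1)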
This commit is contained in:
czzhangheng 2025-12-15 01:38:47 +08:00
parent 97743dfd05
commit 3095b7435b
7 changed files with 318 additions and 685 deletions

View File

@@ -2,95 +2,80 @@ import os
import numpy as np
import h5py
def load_st_dataset(config):
dataset = config["basic"]["dataset"]
# sample = config["data"]["sample"]
# output B, N, D
match dataset:
case "BeijingAirQuality":
data_path = os.path.join("./data/BeijingAirQuality/data.dat")
data = np.memmap(data_path, dtype=np.float32, mode='r')
L, N, C = 36000, 7, 3
data = data.reshape(L, N, C)
case "AirQuality":
data_path = os.path.join("./data/AirQuality/data.dat")
data = np.memmap(data_path, dtype=np.float32, mode='r')
L, N, C = 8701,35,6
data = data.reshape(L, N, C)
case "PEMS-BAY":
data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5")
with h5py.File(data_path, 'r') as f:
data = f['speed']['block0_values'][:]
case "METR-LA":
data_path = os.path.join("./data/METR-LA/METR-LA.h5")
with h5py.File(data_path, 'r') as f:
data = f['df']['block0_values'][:]
case "SolarEnergy":
data_path = os.path.join("./data/SolarEnergy/SolarEnergy.csv")
data = np.loadtxt(data_path, delimiter=",")
case "PEMSD3":
data_path = os.path.join("./data/PEMS03/PEMS03.npz")
data = np.load(data_path)["data"][:, :, 0]
case "PEMSD4":
data_path = os.path.join("./data/PEMS04/PEMS04.npz")
data = np.load(data_path)["data"][:, :, 0]
case "PEMSD7":
data_path = os.path.join("./data/PEMS07/PEMS07.npz")
data = np.load(data_path)["data"][:, :, 0]
case "PEMSD8":
data_path = os.path.join("./data/PEMS08/PEMS08.npz")
data = np.load(data_path)["data"][:, :, 0]
case "PEMSD7(L)":
data_path = os.path.join("./data/PEMS07(L)/PEMS07L.npz")
data = np.load(data_path)["data"][:, :, 0]
case "PEMSD7(M)":
data_path = os.path.join("./data/PEMS07(M)/V_228.csv")
data = np.genfromtxt(data_path, delimiter=",")
case "BJ":
data_path = os.path.join("./data/BJ/BJ500.csv")
data = np.genfromtxt(data_path, delimiter=",", skip_header=1)
case "Hainan":
data_path = os.path.join("./data/Hainan/Hainan.npz")
data = np.load(data_path)["data"][:, :, 0]
case "SD":
data_path = os.path.join("./data/SD/data.npz")
data = np.load(data_path)["data"][:, :, 0].astype(np.float32)
case "BJTaxi-InFlow":
data = read_BeijingTaxi()[:, :, 0:1].astype(np.float32)
case "BJTaxi-OutFlow":
data = read_BeijingTaxi()[:, :, 1:2].astype(np.float32)
case "NYCBike-InFlow":
data_path = os.path.join("./data/NYCBike/NYC16x8.h5")
with h5py.File(data_path, 'r') as f:
data = f['data'][:].astype(np.float32)
data = data.transpose(0,2,3,1).reshape(-1, 16*8, 2)
data = data[:, :, 0:1]
case "NYCBike-OutFlow":
data_path = os.path.join("./data/NYCBike/NYC16x8.h5")
with h5py.File(data_path, 'r') as f:
data = f['data'][:].astype(np.float32)
data = data.transpose(0,2,3,1).reshape(-1, 16*8, 2)
data = data[:, :, 1:2]
case _:
loaders = {
"BeijingAirQuality": lambda: _memmap("./data/BeijingAirQuality/data.dat", 36000, 7, 3),
"AirQuality": lambda: _memmap("./data/AirQuality/data.dat", 8701, 35, 6),
"PEMS-BAY": lambda: _h5("./data/PEMS-BAY/pems-bay.h5", ("speed", "block0_values")),
"METR-LA": lambda: _h5("./data/METR-LA/METR-LA.h5", ("df", "block0_values")),
"SolarEnergy": lambda: np.loadtxt("./data/SolarEnergy/SolarEnergy.csv", delimiter=","),
"PEMSD3": lambda: _npz("./data/PEMS03/PEMS03.npz"),
"PEMSD4": lambda: _npz("./data/PEMS04/PEMS04.npz"),
"PEMSD7": lambda: _npz("./data/PEMS07/PEMS07.npz"),
"PEMSD8": lambda: _npz("./data/PEMS08/PEMS08.npz"),
"PEMSD7(L)": lambda: _npz("./data/PEMS07(L)/PEMS07L.npz"),
"PEMSD7(M)": lambda: np.genfromtxt("./data/PEMS07(M)/V_228.csv", delimiter=","),
"BJ": lambda: np.genfromtxt("./data/BJ/BJ500.csv", delimiter=",", skip_header=1),
"Hainan": lambda: _npz("./data/Hainan/Hainan.npz"),
"SD": lambda: _npz("./data/SD/data.npz", cast=True),
"BJTaxi-InFlow": lambda: read_BeijingTaxi()[:, :, 0:1].astype(np.float32),
"BJTaxi-OutFlow": lambda: read_BeijingTaxi()[:, :, 1:2].astype(np.float32),
"NYCBike-InFlow": lambda: _nyc_bike(0),
"NYCBike-OutFlow": lambda: _nyc_bike(1),
}
if dataset not in loaders:
raise ValueError(f"Unsupported dataset: {dataset}")
# Ensure data shape compatibility
if len(data.shape) == 2:
data = np.expand_dims(data, axis=-1)
data = loaders[dataset]()
print("加载 %s 数据集中... " % dataset)
# return data[::sample]
if data.ndim == 2:
data = data[..., None]
print(f"加载 {dataset} 数据集中... ")
return data
# ---------------- helpers ----------------
def _memmap(path, L, N, C):
data = np.memmap(path, dtype=np.float32, mode="r")
return data.reshape(L, N, C)
def _h5(path, keys):
with h5py.File(path, "r") as f:
return f[keys[0]][keys[1]][:]
def _npz(path, cast=False):
data = np.load(path)["data"][:, :, 0]
return data.astype(np.float32) if cast else data
def _nyc_bike(channel):
with h5py.File("./data/NYCBike/NYC16x8.h5", "r") as f:
data = f["data"][:].astype(np.float32)
data = data.transpose(0, 2, 3, 1).reshape(-1, 16 * 8, 2)
return data[:, :, channel:channel + 1]
def read_BeijingTaxi():
files = ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy",
"TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"]
all_data = []
for file in files:
data_path = os.path.join(f"./data/BeijingTaxi/{file}")
data = np.load(data_path)
all_data.append(data)
all_data = np.concatenate(all_data, axis=0)
time_num = all_data.shape[0]
all_data = all_data.transpose(0, 2, 3, 1).reshape(time_num, 32*32, 2)
return all_data
files = [
"TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy",
"TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy",
]
data = np.concatenate(
[np.load(f"./data/BeijingTaxi/{f}") for f in files], axis=0
)
T = data.shape[0]
return data.transpose(0, 2, 3, 1).reshape(T, 32 * 32, 2)
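A quick usage sketch for the refactored loader, assuming the corresponding files exist under ./data/ and that config follows the layout read above (only config["basic"]["dataset"] is used here; the dataset name is illustrative):

config = {"basic": {"dataset": "PEMSD4"}}
data = load_st_dataset(config)
# 2-D sources get a trailing channel axis, so the result is always (L, N, C).
print(data.shape)  # e.g. (T, num_nodes, 1) for a single-channel PEMS dataset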

View File

@@ -8,21 +8,12 @@ from dataloader.Informer_loader import get_dataloader as Informer_loader
def get_dataloader(config, normalizer, single):
TS_model = ["iTransformer", "HI", "PatchTST"]
model_name = config["basic"]["model"]
# if model_name == "Informer":
# return Informer_loader(config, normalizer, single)
# elif model_name in TS_model:
# return TS_loader(config, normalizer, single)
# else :
match model_name:
case "STGNCDE":
return cde_loader(config, normalizer, single)
case "STGNRDE":
return nrde_loader(config, normalizer, single)
case "DCRNN":
return DCRNN_loader(config, normalizer, single)
case "EXP":
return EXP_loader(config, normalizer, single)
case _:
return normal_loader(config, normalizer, single)
loader_map = {
"STGNCDE": cde_loader,
"STGNRDE": nrde_loader,
"DCRNN": DCRNN_loader,
"EXP": EXP_loader,
}
return loader_map.get(config["basic"]["model"], normal_loader)(
config, normalizer, single
)
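The dict.get fallback replaces the old catch-all branch: unknown model names now route to normal_loader rather than needing a dedicated case. A tiny stand-in illustration:

def normal_loader(*args):
    return "normal"

def dcrnn_loader(*args):
    return "dcrnn"

loader_map = {"DCRNN": dcrnn_loader}
print(loader_map.get("SomeNewModel", normal_loader)())  # normal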

View File

@@ -11,9 +11,9 @@ def read_config(config_path):
config = yaml.safe_load(file)
# global configuration
device = "cpu" # compute device
device = "cuda:0" # set the device to cuda:0
seed = 2023 # random seed
epochs = 120
epochs = 1
# copied into config
config["basic"]["device"] = device
@@ -65,8 +65,8 @@ def main(debug=False):
model_list = ["iTransformer"]
# select datasets
# dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"]
dataset_list = ["AirQuality"]
# dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"]
# dataset_list = ["AirQuality"]
dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"]
# my debug switch; set it to str(False) when not testing
# os.environ["TRY"] = str(False)

View File

@@ -1,296 +1,195 @@
import math
import os
import time
import copy
import torch
import os, time, copy, torch
from tqdm import tqdm
from utils.logger import get_logger
from utils.loss_function import all_metrics
from tqdm import tqdm
class TSWrapper:
def __init__(self, args):
self.b = args['train']['batch_size']
self.t = args['data']['lag']
self.n = args['data']['num_nodes']
self.c = args['data']['input_dim']
def transpose(self, x : torch.Tensor):
def forward(self, x):
# [b, t, n, c] -> [b*n, t, c]
self.b = x.shape[0]
x = x[..., :-2]
x = x.permute(0, 2, 1, 3)
x = x.reshape(self.b*self.n, self.t, self.c)
return x
b, t, n, c = x.shape
x = x[..., :-2].permute(0, 2, 1, 3).reshape(b * n, t, c-2)
return x, b, t, n, c
def inv_transpose(self, x : torch.Tensor):
x = x.reshape(self.b, self.n, self.t, self.c)
x = x.permute(0, 2, 1, 3)
return x
def inverse(self, x, b, t, n, c):
return x.reshape(b, n, t, c-2).permute(0, 2, 1, 3)
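A shape check for the wrapper's round trip, assuming the last two input channels are auxiliary covariates that the time-series model does not consume (all sizes below are made up):

import torch

b, t, n, c = 4, 12, 35, 8  # batch, lag, nodes, channels (last 2 auxiliary)
x = torch.randn(b, t, n, c)
flat = x[..., :-2].permute(0, 2, 1, 3).reshape(b * n, t, c - 2)  # forward: [b, t, n, c] -> [b*n, t, c-2]
back = flat.reshape(b, n, t, c - 2).permute(0, 2, 1, 3)          # inverse: -> [b, t, n, c-2]
assert back.shape == (b, t, n, c - 2)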
class Trainer:
"""模型训练器,负责整个训练流程的管理"""
def __init__(self, model, loss, optimizer,
train_loader, val_loader, test_loader, scaler,
args, lr_scheduler=None,):
# device and basic parameters
train_loader, val_loader, test_loader,
scaler, args, lr_scheduler=None):
self.config = args
self.device = args["basic"]["device"]
train_args = args["train"]
# model and training components
self.model = model
self.args = args["train"]
self.model = model.to(self.device)
self.loss = loss
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
# data loaders
self.train_loader = train_loader
self.val_loader = val_loader
self.val_loader = val_loader or test_loader
self.test_loader = test_loader
# data utilities
self.scaler = scaler
self.args = train_args
self.ts_wrapper = TSWrapper(args)
# initialize paths, logging, and stats
self._initialize_paths(train_args)
self._initialize_logger(train_args)
def _initialize_paths(self, args):
"""初始化模型保存路径"""
self.best_path = os.path.join(args["log_dir"], "best_model.pth")
self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth")
self.loss_figure_path = os.path.join(args["log_dir"], "loss.png")
self.ts = TSWrapper(args)
self._init_paths()
self._init_logger()
def _initialize_logger(self, args):
"""初始化日志记录器"""
if not os.path.isdir(args["log_dir"]) and not args["debug"]:
os.makedirs(args["log_dir"], exist_ok=True)
self.logger = get_logger(args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"])
self.logger.info(f"Experiment log path in: {args['log_dir']}")
# ---------------- init ----------------
def _init_paths(self):
d = self.args["log_dir"]
self.best_path = os.path.join(d, "best_model.pth")
self.best_test_path = os.path.join(d, "best_test_model.pth")
def _run_epoch(self, epoch, dataloader, mode):
"""运行一个训练/验证/测试epoch"""
# 设置模型模式和是否进行优化
if mode == "train": self.model.train(); optimizer_step = True
else: self.model.eval(); optimizer_step = False
def _init_logger(self):
if not self.args["debug"]:
os.makedirs(self.args["log_dir"], exist_ok=True)
self.logger = get_logger(
self.args["log_dir"],
name=self.model.__class__.__name__,
debug=self.args["debug"],
)
# initialize accumulators
total_loss = 0
epoch_time = time.time()
# ---------------- epoch ----------------
def _run_epoch(self, epoch, loader, mode):
is_train = mode == "train"
self.model.train() if is_train else self.model.eval()
total_loss, start = 0.0, time.time()
y_pred, y_true = [], []
# train/val loop
with torch.set_grad_enabled(optimizer_step):
progress_bar = tqdm(
enumerate(dataloader),
total=len(dataloader),
desc=f"{mode.capitalize()} Epoch {epoch}"
)
for _, (data, target) in progress_bar:
# move data to the device
data = data.to(self.device)
target = target.to(self.device)
label = target[..., : self.args["output_dim"]]
# reshape to [b*n, t, c]
data = self.ts_wrapper.transpose(data)
# compute the loss and the denormalized loss
output = self.model(data)
# reshape back to [b, t, n, c]
output = self.ts_wrapper.inv_transpose(output)
# my debug switch
with torch.set_grad_enabled(is_train):
for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)):
data, target = data.to(self.device), target.to(self.device)
label = target[..., :self.args["output_dim"]]
x, b, t, n, c = self.ts.forward(data)
out = self.model(x)
out = self.ts.inverse(out, b, t, n, c)
if os.environ.get("TRY") == "True":
print(f"[{'' if output.shape == label.shape else ''}]: output: {output.shape}, label: {label.shape}")
assert False
loss = self.loss(output, label)
d_output = self.scaler.inverse_transform(output)
d_label = self.scaler.inverse_transform(label)
d_loss = self.loss(d_output, d_label)
# accumulate loss and predictions
print(out.shape, label.shape)
assert out.shape == label.shape
loss = self.loss(out, label)
d_out = self.scaler.inverse_transform(out)
d_lbl = self.scaler.inverse_transform(label)
d_loss = self.loss(d_out, d_lbl)
total_loss += d_loss.item()
y_pred.append(d_output.detach().cpu())
y_true.append(d_label.detach().cpu())
# backprop and optimize (training mode only)
if optimizer_step and self.optimizer is not None:
y_pred.append(d_out.detach().cpu())
y_true.append(d_lbl.detach().cpu())
if is_train and self.optimizer:
self.optimizer.zero_grad()
loss.backward()
# gradient clipping (if enabled)
if self.args["grad_norm"]:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.args["max_grad_norm"]
)
self.optimizer.step()
# update the progress bar
progress_bar.set_postfix(loss=d_loss.item())
# concatenate predictions from all batches
y_pred = torch.cat(y_pred, dim=0)
y_true = torch.cat(y_true, dim=0)
# compute the loss and log metrics
avg_loss = total_loss / len(dataloader)
mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
self.logger.info(
f"Epoch #{epoch:02d}: {mode.capitalize():<5} "
f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s"
y_pred = torch.cat(y_pred)
y_true = torch.cat(y_true)
mae, rmse, mape = all_metrics(
y_pred, y_true,
self.args["mae_thresh"],
self.args["mape_thresh"]
)
return avg_loss
def train_epoch(self, epoch):
return self._run_epoch(epoch, self.train_loader, "train")
def val_epoch(self, epoch):
return self._run_epoch(epoch, self.val_loader or self.test_loader, "val")
def test_epoch(self, epoch):
return self._run_epoch(epoch, self.test_loader, "test")
self.logger.info(
f"Epoch #{epoch:02d} {mode:<5} "
f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} "
f"MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s"
)
return total_loss / len(loader)
# ---------------- train ----------------
def train(self):
# initialize tracking
best_model, best_test_model = None, None
best_loss, best_test_loss = float("inf"), float("inf")
not_improved_count = 0
# start training
self.logger.info("Training process started")
# training loop
best, best_test = float("inf"), float("inf")
best_w, best_test_w = None, None
patience = 0
self.logger.info("Training started")
for epoch in range(1, self.args["epochs"] + 1):
# run train, val, and test for one epoch
train_epoch_loss = self.train_epoch(epoch)
val_epoch_loss = self.val_epoch(epoch)
test_epoch_loss = self.test_epoch(epoch)
# check for gradient explosion
if train_epoch_loss > 1e6:
self.logger.warning("Gradient explosion detected. Ending...")
losses = {
"train": self._run_epoch(epoch, self.train_loader, "train"),
"val": self._run_epoch(epoch, self.val_loader, "val"),
"test": self._run_epoch(epoch, self.test_loader, "test"),
}
if losses["train"] > 1e6:
self.logger.warning("Gradient explosion detected")
break
# update the best validation model
if val_epoch_loss < best_loss:
best_loss = val_epoch_loss
not_improved_count = 0
best_model = copy.deepcopy(self.model.state_dict())
self.logger.info("Best validation model saved!")
if losses["val"] < best:
best, patience = losses["val"], 0
best_w = copy.deepcopy(self.model.state_dict())
self.logger.info("Best validation model saved")
else:
not_improved_count += 1
# early stopping
if self._should_early_stop(not_improved_count):
patience += 1
if self.args["early_stop"] and patience == self.args["early_stop_patience"]:
self.logger.info("Early stopping triggered")
break
# update the best test model
if test_epoch_loss < best_test_loss:
best_test_loss = test_epoch_loss
best_test_model = copy.deepcopy(self.model.state_dict())
# save the best models
if losses["test"] < best_test:
best_test = losses["test"]
best_test_w = copy.deepcopy(self.model.state_dict())
if not self.args["debug"]:
self._save_best_models(best_model, best_test_model)
# final evaluation
self._finalize_training(best_model, best_test_model)
torch.save(best_w, self.best_path)
torch.save(best_test_w, self.best_test_path)
def _should_early_stop(self, not_improved_count):
"""检查是否满足早停条件"""
if (
self.args["early_stop"]
and not_improved_count == self.args["early_stop_patience"]
):
self.logger.info(
f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops."
)
return True
return False
self._final_test(best_w, best_test_w)
def _save_best_models(self, best_model, best_test_model):
"""保存最佳模型到文件"""
torch.save(best_model, self.best_path)
torch.save(best_test_model, self.best_test_path)
self.logger.info(
f"Best models saved at {self.best_path} and {self.best_test_path}"
)
# ---------------- final test ----------------
def _final_test(self, best_w, best_test_w):
for name, w in [("best val", best_w), ("best test", best_test_w)]:
self.model.load_state_dict(w)
self.logger.info(f"Testing on {name} model")
self.evaluate()
def _log_model_params(self):
"""输出模型可训练参数数量"""
total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
def _finalize_training(self, best_model, best_test_model):
self.model.load_state_dict(best_model)
self.logger.info("Testing on best validation model")
self.test(self.model, self.config, self.test_loader, self.scaler, self.logger)
self.model.load_state_dict(best_test_model)
self.logger.info("Testing on best test model")
self.test(self.model, self.config, self.test_loader, self.scaler, self.logger)
@staticmethod
def test(model, args, data_loader, scaler, logger, path=None):
"""对模型进行评估并输出性能指标"""
# 确定设备信息
device = None
output_dim = None
# 处理不同的参数格式
if isinstance(args, dict):
if "basic" in args:
# full config case
device = args["basic"]["device"]
output_dim = args["train"]["output_dim"]
else:
# train_args-only case: take the device from the model
device = next(model.parameters()).device
output_dim = args["output_dim"]
else:
raise ValueError(f"Unsupported args type: {type(args)}")
# load a model checkpoint if a path is given
if path:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["state_dict"])
model.to(device)
# switch to evaluation mode
model.eval()
# collect predictions and ground truth
# ---------------- evaluate ----------------
def evaluate(self):
self.model.eval()
y_pred, y_true = [], []
# predict without computing gradients
with torch.no_grad():
for data, target in data_loader:
# move data and labels to the device
data = data.to(device)
target = target.to(device)
for data, target in self.test_loader:
data, target = data.to(self.device), target.to(self.device)
label = target[..., :self.args["output_dim"]]
data = data[..., :-2]
b, t, n, c = data.shape
data = data.permute(0, 2, 1, 3)
data = data.reshape(b*n, t, c)
label = target[..., : output_dim]
output = model(data)
output = output.reshape(b, n, t, c)
output = output.permute(0, 2, 1, 3)
x, b, t, n, c = self.ts.forward(data)
out = self.model(x)
out = self.ts.inverse(out, b, t, n, c)
y_pred.append(output.detach().cpu())
y_true.append(label.detach().cpu())
y_pred.append(out.cpu())
y_true.append(label.cpu())
d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
d_pred = self.scaler.inverse_transform(torch.cat(y_pred))
d_true = self.scaler.inverse_transform(torch.cat(y_true))
# get the metrics parameters
if "basic" in args:
# full config case
mae_thresh = args["train"]["mae_thresh"]
mape_thresh = args["train"]["mape_thresh"]
else:
# train_args-only case
mae_thresh = args["mae_thresh"]
mape_thresh = args["mape_thresh"]
# compute and log metrics for each horizon
for t in range(d_y_true.shape[1]):
for t in range(d_true.shape[1]):
mae, rmse, mape = all_metrics(
d_y_pred[:, t, ...],
d_y_true[:, t, ...],
mae_thresh,
mape_thresh,
d_pred[:, t], d_true[:, t],
self.args["mae_thresh"],
self.args["mape_thresh"]
)
self.logger.info(
f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}"
)
logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
# compute and log the average metrics
mae, rmse, mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh)
logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"])
self.logger.info(
f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}"
)
@staticmethod
def _compute_sampling_threshold(global_step, k):
return k / (k + math.exp(global_step / k))
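For reference, the sampling threshold above is an inverse-sigmoid decay used for scheduled sampling: it starts near 1 and falls toward 0 as the global step grows, with k controlling how late and how sharply it drops (it crosses 0.5 around step k*ln(k)). A quick numeric illustration with a k chosen only for the demo:

import math

def compute_sampling_threshold(global_step, k):
    return k / (k + math.exp(global_step / k))

for step in (0, 100, 461, 1000):
    print(step, round(compute_sampling_threshold(step, k=100), 3))  # ~0.990, 0.974, 0.499, 0.005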

View File

@@ -1,4 +1,3 @@
import math
import os
import time
import copy
@@ -8,240 +7,100 @@ from utils.loss_function import all_metrics
from tqdm import tqdm
class Trainer:
"""模型训练器,负责整个训练流程的管理"""
def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None):
# device and basic parameters
self.config = args
self.device = args["basic"]["device"]
self.args = args["train"]
# model and training components
self.config, self.device, self.args = args, args["basic"]["device"], args["train"]
self.model, self.loss, self.optimizer, self.lr_scheduler = model, loss, optimizer, lr_scheduler
self.train_loader, self.val_loader, self.test_loader, self.scaler = train_loader, val_loader, test_loader, scaler
# data loaders
self.train_loader, self.val_loader, self.test_loader = train_loader, val_loader, test_loader
log_dir = self.args["log_dir"]
self.best_path, self.best_test_path = [os.path.join(log_dir, name) for name in ("best_model.pth", "best_test_model.pth")]
# data utilities
self.scaler = scaler
# initialize paths, logging, and stats
self._initialize_paths(self.args)
self._initialize_logger(self.args)
def _initialize_paths(self, args):
"""初始化模型保存路径"""
log_dir = args["log_dir"]
self.best_path = os.path.join(log_dir, "best_model.pth")
self.best_test_path = os.path.join(log_dir, "best_test_model.pth")
self.loss_figure_path = os.path.join(log_dir, "loss.png")
def _initialize_logger(self, args):
"""初始化日志记录器"""
log_dir = args["log_dir"]
if not args["debug"]:
os.makedirs(log_dir, exist_ok=True)
self.logger = get_logger(log_dir, name=self.model.__class__.__name__, debug=args["debug"])
if not self.args["debug"]: os.makedirs(log_dir, exist_ok=True)
self.logger = get_logger(log_dir, name=self.model.__class__.__name__, debug=self.args["debug"])
self.logger.info(f"Experiment log path in: {log_dir}")
def _run_epoch(self, epoch, dataloader, mode):
"""运行一个训练/验证/测试epoch"""
# 设置模型模式和是否进行优化
self.model.train() if mode == "train" else self.model.eval()
optimizer_step = mode == "train"
# initialize accumulators
total_loss = 0
epoch_time = time.time()
y_pred, y_true = [], []
# train/val loop
with torch.set_grad_enabled(optimizer_step):
progress_bar = tqdm(
dataloader,
total=len(dataloader),
desc=f"{mode.capitalize()} Epoch {epoch}"
)
for data, target in progress_bar:
# move data and extract labels
data, target = data.to(self.device), target.to(self.device)
label = target[..., : self.args["output_dim"]]
# compute the output
output = self.model(data)
# my debug switch
if os.environ.get("TRY") == "True":
status = 'OK' if output.shape == label.shape else 'MISMATCH'
print(f"[{status}]: output: {output.shape}, label: {label.shape}")
assert False
# compute the losses
loss = self.loss(output, label)
d_output = self.scaler.inverse_transform(output)
d_label = self.scaler.inverse_transform(label)
d_loss = self.loss(d_output, d_label)
# accumulate loss and predictions
total_loss += d_loss.item()
y_pred.append(d_output.detach().cpu())
y_true.append(d_label.detach().cpu())
# backprop and optimize (training mode only)
if optimizer_step and self.optimizer is not None:
self.optimizer.zero_grad()
loss.backward()
# gradient clipping (if enabled)
if self.args["grad_norm"]:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
self.optimizer.step()
# update the progress bar
progress_bar.set_postfix(loss=d_loss.item())
# concatenate predictions from all batches
y_pred, y_true = torch.cat(y_pred, dim=0), torch.cat(y_true, dim=0)
# compute the loss and log metrics
avg_loss = total_loss / len(dataloader)
mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
self.logger.info(
f"Epoch #{epoch:02d}: {mode.capitalize():<5} "
f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s"
)
return avg_loss
def train(self):
# initialize tracking
best_model = best_test_model = None
best_loss = best_test_loss = float("inf")
not_improved_count = 0
# start training
self.logger.info("Training process started")
# training loop
for epoch in range(1, self.args["epochs"] + 1):
# run train, val, and test for one epoch
train_epoch_loss = self._run_epoch(epoch, self.train_loader, "train")
val_epoch_loss = self._run_epoch(epoch, self.val_loader or self.test_loader, "val")
test_epoch_loss = self._run_epoch(epoch, self.test_loader, "test")
train_loss = self._run_epoch(epoch, self.train_loader, "train")
val_loss = self._run_epoch(epoch, self.val_loader or self.test_loader, "val")
test_loss = self._run_epoch(epoch, self.test_loader, "test")
# check for gradient explosion
if train_epoch_loss > 1e6:
if train_loss > 1e6:
self.logger.warning("Gradient explosion detected. Ending...")
break
# update the best validation model
if val_epoch_loss < best_loss:
best_loss, not_improved_count = val_epoch_loss, 0
best_model = copy.deepcopy(self.model.state_dict())
if val_loss < best_loss:
best_loss, not_improved_count, best_model = val_loss, 0, copy.deepcopy(self.model.state_dict())
self.logger.info("Best validation model saved!")
else:
not_improved_count += 1
# early-stopping check
if self._should_early_stop(not_improved_count):
elif self.args["early_stop"] and (not_improved_count := not_improved_count + 1) == self.args["early_stop_patience"]:
self.logger.info(f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
break
# update the best test model
if test_epoch_loss < best_test_loss:
best_test_loss = test_epoch_loss
best_test_model = copy.deepcopy(self.model.state_dict())
if test_loss < best_test_loss:
best_test_loss, best_test_model = test_loss, copy.deepcopy(self.model.state_dict())
# save the best models
if not self.args["debug"]:
self._save_best_models(best_model, best_test_model)
# final evaluation
self._finalize_training(best_model, best_test_model)
def _should_early_stop(self, not_improved_count):
"""检查是否满足早停条件"""
if self.args["early_stop"] and not_improved_count == self.args["early_stop_patience"]:
self.logger.info(
f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops."
)
return True
return False
def _save_best_models(self, best_model, best_test_model):
"""保存最佳模型到文件"""
torch.save(best_model, self.best_path)
torch.save(best_test_model, self.best_test_path)
self.logger.info(
f"Best models saved at {self.best_path} and {self.best_test_path}"
)
self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
def _log_model_params(self):
"""输出模型可训练参数数量"""
total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
for model_name, state_dict in [("best validation", best_model), ("best test", best_test_model)]:
self.model.load_state_dict(state_dict)
self.logger.info(f"Testing on {model_name} model")
self._run_epoch(None, self.test_loader, "test", log_horizon=True)
def _run_epoch(self, epoch, dataloader, mode, log_horizon=False):
self.model.train() if mode == "train" else self.model.eval()
optimizer_step = mode == "train"
def _finalize_training(self, best_model, best_test_model):
self.model.load_state_dict(best_model)
self.logger.info("Testing on best validation model")
self.test(self.model, self.config, self.test_loader, self.scaler, self.logger)
self.model.load_state_dict(best_test_model)
self.logger.info("Testing on best test model")
self.test(self.model, self.config, self.test_loader, self.scaler, self.logger)
@staticmethod
def test(model, args, data_loader, scaler, logger, path=None):
"""对模型进行评估并输出性能指标"""
# 验证参数类型
if not isinstance(args, dict):
raise ValueError(f"Unsupported args type: {type(args)}")
# determine the device and output dimension
is_full_config = "basic" in args
device = args["basic"]["device"] if is_full_config else next(model.parameters()).device
output_dim = args["train"]["output_dim"] if is_full_config else args["output_dim"]
# get the metrics parameters
train_args = args["train"] if is_full_config else args
mae_thresh, mape_thresh = train_args["mae_thresh"], train_args["mape_thresh"]
# load a model checkpoint if a path is given
if path:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["state_dict"])
model.to(device)
# switch to eval mode and collect predictions
model.eval()
total_loss, epoch_time = 0, time.time()
y_pred, y_true = [], []
# predict without computing gradients
with torch.no_grad():
for data, target in data_loader:
# move data and labels to the device
data, target = data.to(device), target.to(device)
label = target[..., : output_dim]
with torch.set_grad_enabled(optimizer_step):
for data, target in tqdm(dataloader, total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" if epoch else mode):
data, target = data.to(self.device), target.to(self.device)
label = target[..., :self.args["output_dim"]]
output = model(data)
y_pred.append(output.detach().cpu())
y_true.append(label.detach().cpu())
output = self.model(data)
loss = self.loss(output, label)
d_output, d_label = self.scaler.inverse_transform(output), self.scaler.inverse_transform(label)
d_loss = self.loss(d_output, d_label)
# denormalize and compute metrics
d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
total_loss += d_loss.item()
y_pred.append(d_output.detach().cpu())
y_true.append(d_label.detach().cpu())
# compute and log metrics for each horizon
for t in range(d_y_true.shape[1]):
mae, rmse, mape = all_metrics(
d_y_pred[:, t, ...],
d_y_true[:, t, ...],
mae_thresh,
mape_thresh,
)
logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
if optimizer_step and self.optimizer:
self.optimizer.zero_grad()
loss.backward()
if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
self.optimizer.step()
# compute and log the average metrics
avg_mae, avg_rmse, avg_mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh)
logger.info(f"Average Horizon, MAE: {avg_mae:.4f}, RMSE: {avg_rmse:.4f}, MAPE: {avg_mape:.4f}")
y_pred, y_true = torch.cat(y_pred, dim=0), torch.cat(y_true, dim=0)
if log_horizon:
for t in range(y_true.shape[1]):
mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], self.args["mae_thresh"], self.args["mape_thresh"])
self.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
avg_mae, avg_rmse, avg_mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
if epoch and mode:
self.logger.info(f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{avg_mae:5.2f} | RMSE:{avg_rmse:5.2f} | MAPE:{avg_mape:7.4f} | Time: {time.time()-epoch_time:.2f} s")
elif mode:
self.logger.info(f"{mode.capitalize():<5} MAE:{avg_mae:.4f} | RMSE:{avg_rmse:.4f} | MAPE:{avg_mape:.4f}")
return total_loss / len(dataloader)
def test(self, path=None):
if path:
self.model.load_state_dict(torch.load(path)["state_dict"])
self.model.to(self.device)
self._run_epoch(None, self.test_loader, "test", log_horizon=True)
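The patience counter above relies on an assignment expression inside the elif; a minimal standalone sketch of the same early-stopping logic (losses and patience values are illustrative):

best, not_improved, patience = float("inf"), 0, 3
for epoch, val_loss in enumerate([1.0, 0.8, 0.82, 0.81, 0.83, 0.84], start=1):
    if val_loss < best:
        best, not_improved = val_loss, 0  # an improvement resets the counter
    elif (not_improved := not_improved + 1) == patience:
        print(f"early stop at epoch {epoch}")  # epoch 5 here
        break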

View File

@@ -7,132 +7,31 @@ from trainer.E32Trainer import Trainer as EXP_Trainer
from trainer.InformerTrainer import InformerTrainer
from trainer.TSTrainer import Trainer as TSTrainer
def select_trainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
kwargs,
model, loss, optimizer,
train_loader, val_loader, test_loader,
scaler, args, lr_scheduler, kwargs
):
model_name = args["basic"]["model"]
TS_model = ["HI", "PatchTST", "iTransformer"]
if model_name in TS_model:
return TSTrainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
base_args = (
model, loss, optimizer,
train_loader, val_loader, test_loader,
scaler, args, lr_scheduler
)
if model_name in {"HI", "PatchTST", "iTransformer"}:
return TSTrainer(*base_args)
match model_name:
case "STGNCDE":
return cdeTrainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
kwargs[0],
None,
)
case "STGNRDE":
return cdeTrainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
kwargs[0],
None,
)
case "DCRNN":
return DCRNN_Trainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
)
case "PDG2SEQ":
return PDG2SEQ_Trainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
)
case "STMLP":
return STMLP_Trainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
)
case "EXP":
return EXP_Trainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
)
case "Informer":
return InformerTrainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
)
case _:
return Trainer(
model,
loss,
optimizer,
train_loader,
val_loader,
test_loader,
scaler,
args,
lr_scheduler,
)
trainer_map = {
"DCRNN": DCRNN_Trainer,
"PDG2SEQ": PDG2SEQ_Trainer,
"STMLP": STMLP_Trainer,
"EXP": EXP_Trainer,
"Informer": InformerTrainer,
}
if model_name in {"STGNCDE", "STGNRDE"}:
return cdeTrainer(*base_args, kwargs[0], None)
return trainer_map.get(model_name, Trainer)(*base_args)

View File

@@ -18,7 +18,7 @@ def get_logger(root, name=None, debug=True):
logger.handlers.clear()
# timestamp format for log messages
formatter = logging.Formatter("%(asctime)s - %(message)s", "%Y/%m/%d %H:%M:%S")
formatter = logging.Formatter("%(asctime)s - %(message)s", "%m/%d %H:%M")
# console output
console_handler = logging.StreamHandler()