From 3095b7435b42d2833e002381a4320507504bb8f4 Mon Sep 17 00:00:00 2001
From: czzhangheng
Date: Mon, 15 Dec 2025 01:38:47 +0800
Subject: [PATCH] refactor: restructure data loader and trainer code; improve
 code structure and readability
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Refactor the data loader module, replacing the match/case dispatch with a
dictionary mapping.
Simplify trainer logic, merge duplicated code, and improve maintainability.
Shorten the log timestamp format.
Adjust the training config: reduce the default epoch count and enable GPU
training.
Unify data loading and extract shared helpers to reduce duplication.
---
 dataloader/data_selector.py   | 155 ++++++--------
 dataloader/loader_selector.py |  27 +--
 train.py                      |   8 +-
 trainer/TSTrainer.py          | 387 +++++++++++++---------------------
 trainer/Trainer.py            | 277 ++++++------------------
 trainer/trainer_selector.py   | 147 ++-----------
 utils/logger.py               |   2 +-
 7 files changed, 318 insertions(+), 685 deletions(-)
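For readers unfamiliar with the pattern, the dictionary dispatch that replaces
match/case throughout this patch works like this (a minimal sketch, not code
from the patch itself):

    # Map keys to callables; dict.get() supplies the fallback that the
    # match/case wildcard branch used to provide.
    def load_a():
        return "loaded A"

    def load_fallback():
        return "loaded fallback"

    dispatch = {"a": load_a}
    print(dispatch.get("a", load_fallback)())    # -> "loaded A"
    print(dispatch.get("zzz", load_fallback)())  # -> "loaded fallback"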
diff --git a/dataloader/data_selector.py b/dataloader/data_selector.py
index e0b23e1..bd8e61a 100644
--- a/dataloader/data_selector.py
+++ b/dataloader/data_selector.py
@@ -2,95 +2,80 @@ import os
 import numpy as np
 import h5py
 
+
 def load_st_dataset(config):
     dataset = config["basic"]["dataset"]
-    # sample = config["data"]["sample"]
-    # output B, N, D
-    match dataset:
-        case "BeijingAirQuality":
-            data_path = os.path.join("./data/BeijingAirQuality/data.dat")
-            data = np.memmap(data_path, dtype=np.float32, mode='r')
-            L, N, C = 36000, 7, 3
-            data = data.reshape(L, N, C)
-        case "AirQuality":
-            data_path = os.path.join("./data/AirQuality/data.dat")
-            data = np.memmap(data_path, dtype=np.float32, mode='r')
-            L, N, C = 8701, 35, 6
-            data = data.reshape(L, N, C)
-        case "PEMS-BAY":
-            data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5")
-            with h5py.File(data_path, 'r') as f:
-                data = f['speed']['block0_values'][:]
-        case "METR-LA":
-            data_path = os.path.join("./data/METR-LA/METR-LA.h5")
-            with h5py.File(data_path, 'r') as f:
-                data = f['df']['block0_values'][:]
-        case "SolarEnergy":
-            data_path = os.path.join("./data/SolarEnergy/SolarEnergy.csv")
-            data = np.loadtxt(data_path, delimiter=",")
-        case "PEMSD3":
-            data_path = os.path.join("./data/PEMS03/PEMS03.npz")
-            data = np.load(data_path)["data"][:, :, 0]
-        case "PEMSD4":
-            data_path = os.path.join("./data/PEMS04/PEMS04.npz")
-            data = np.load(data_path)["data"][:, :, 0]
-        case "PEMSD7":
-            data_path = os.path.join("./data/PEMS07/PEMS07.npz")
-            data = np.load(data_path)["data"][:, :, 0]
-        case "PEMSD8":
-            data_path = os.path.join("./data/PEMS08/PEMS08.npz")
-            data = np.load(data_path)["data"][:, :, 0]
-        case "PEMSD7(L)":
-            data_path = os.path.join("./data/PEMS07(L)/PEMS07L.npz")
-            data = np.load(data_path)["data"][:, :, 0]
-        case "PEMSD7(M)":
-            data_path = os.path.join("./data/PEMS07(M)/V_228.csv")
-            data = np.genfromtxt(data_path, delimiter=",")
-        case "BJ":
-            data_path = os.path.join("./data/BJ/BJ500.csv")
-            data = np.genfromtxt(data_path, delimiter=",", skip_header=1)
-        case "Hainan":
-            data_path = os.path.join("./data/Hainan/Hainan.npz")
-            data = np.load(data_path)["data"][:, :, 0]
-        case "SD":
-            data_path = os.path.join("./data/SD/data.npz")
-            data = np.load(data_path)["data"][:, :, 0].astype(np.float32)
-        case "BJTaxi-InFlow":
-            data = read_BeijingTaxi()[:, :, 0:1].astype(np.float32)
-        case "BJTaxi-OutFlow":
-            data = read_BeijingTaxi()[:, :, 1:2].astype(np.float32)
-        case "NYCBike-InFlow":
-            data_path = os.path.join("./data/NYCBike/NYC16x8.h5")
-            with h5py.File(data_path, 'r') as f:
-                data = f['data'][:].astype(np.float32)
-            data = data.transpose(0,2,3,1).reshape(-1, 16*8, 2)
-            data = data[:, :, 0:1]
-        case "NYCBike-OutFlow":
-            data_path = os.path.join("./data/NYCBike/NYC16x8.h5")
-            with h5py.File(data_path, 'r') as f:
-                data = f['data'][:].astype(np.float32)
-            data = data.transpose(0,2,3,1).reshape(-1, 16*8, 2)
-            data = data[:, :, 1:2]
-        case _:
-            raise ValueError(f"Unsupported dataset: {dataset}")
-    # Ensure data shape compatibility
-    if len(data.shape) == 2:
-        data = np.expand_dims(data, axis=-1)
+    loaders = {
+        "BeijingAirQuality": lambda: _memmap("./data/BeijingAirQuality/data.dat", 36000, 7, 3),
+        "AirQuality": lambda: _memmap("./data/AirQuality/data.dat", 8701, 35, 6),
 
-    print("加载 %s 数据集中... " % dataset)
-    # return data[::sample]
+        "PEMS-BAY": lambda: _h5("./data/PEMS-BAY/pems-bay.h5", ("speed", "block0_values")),
+        "METR-LA": lambda: _h5("./data/METR-LA/METR-LA.h5", ("df", "block0_values")),
+
+        "SolarEnergy": lambda: np.loadtxt("./data/SolarEnergy/SolarEnergy.csv", delimiter=","),
+
+        "PEMSD3": lambda: _npz("./data/PEMS03/PEMS03.npz"),
+        "PEMSD4": lambda: _npz("./data/PEMS04/PEMS04.npz"),
+        "PEMSD7": lambda: _npz("./data/PEMS07/PEMS07.npz"),
+        "PEMSD8": lambda: _npz("./data/PEMS08/PEMS08.npz"),
+
+        "PEMSD7(L)": lambda: _npz("./data/PEMS07(L)/PEMS07L.npz"),
+        "PEMSD7(M)": lambda: np.genfromtxt("./data/PEMS07(M)/V_228.csv", delimiter=","),
+
+        "BJ": lambda: np.genfromtxt("./data/BJ/BJ500.csv", delimiter=",", skip_header=1),
+        "Hainan": lambda: _npz("./data/Hainan/Hainan.npz"),
+        "SD": lambda: _npz("./data/SD/data.npz", cast=True),
+
+        "BJTaxi-InFlow": lambda: read_BeijingTaxi()[:, :, 0:1].astype(np.float32),
+        "BJTaxi-OutFlow": lambda: read_BeijingTaxi()[:, :, 1:2].astype(np.float32),
+
+        "NYCBike-InFlow": lambda: _nyc_bike(0),
+        "NYCBike-OutFlow": lambda: _nyc_bike(1),
+    }
+
+    if dataset not in loaders:
+        raise ValueError(f"Unsupported dataset: {dataset}")
+
+    data = loaders[dataset]()
+
+    if data.ndim == 2:
+        data = data[..., None]
+
+    print(f"加载 {dataset} 数据集中... ")
     return data
 
+
+# ---------------- helpers ----------------
+def _memmap(path, L, N, C):
+    data = np.memmap(path, dtype=np.float32, mode="r")
+    return data.reshape(L, N, C)
+
+
+def _h5(path, keys):
+    with h5py.File(path, "r") as f:
+        return f[keys[0]][keys[1]][:]
+
+
+def _npz(path, cast=False):
+    data = np.load(path)["data"][:, :, 0]
+    return data.astype(np.float32) if cast else data
+
+
+def _nyc_bike(channel):
+    with h5py.File("./data/NYCBike/NYC16x8.h5", "r") as f:
+        data = f["data"][:].astype(np.float32)
+    data = data.transpose(0, 2, 3, 1).reshape(-1, 16 * 8, 2)
+    return data[:, :, channel:channel + 1]
+
+
 def read_BeijingTaxi():
-    files = ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy",
-             "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"]
-    all_data = []
-    for file in files:
-        data_path = os.path.join(f"./data/BeijingTaxi/{file}")
-        data = np.load(data_path)
-        all_data.append(data)
-    all_data = np.concatenate(all_data, axis=0)
-    time_num = all_data.shape[0]
-    all_data = all_data.transpose(0, 2, 3, 1).reshape(time_num, 32*32, 2)
-    return all_data
\ No newline at end of file
+    files = [
+        "TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy",
+        "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy",
+    ]
+    data = np.concatenate(
+        [np.load(f"./data/BeijingTaxi/{f}") for f in files], axis=0
+    )
+    T = data.shape[0]
+    return data.transpose(0, 2, 3, 1).reshape(T, 32 * 32, 2)
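For orientation, a minimal call into the refactored loader looks like the
following (illustrative; it assumes ./data/PEMS04/PEMS04.npz exists with a
"data" array, per the mapping above):

    # "PEMSD4" routes to _npz("./data/PEMS04/PEMS04.npz"), which slices
    # channel 0; the 2-D result is then expanded to shape (L, N, 1).
    config = {"basic": {"dataset": "PEMSD4"}}
    data = load_st_dataset(config)
    print(data.shape)  # -> (L, N, 1)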
") return data + +# ---------------- helpers ---------------- +def _memmap(path, L, N, C): + data = np.memmap(path, dtype=np.float32, mode="r") + return data.reshape(L, N, C) + + +def _h5(path, keys): + with h5py.File(path, "r") as f: + return f[keys[0]][keys[1]][:] + + +def _npz(path, cast=False): + data = np.load(path)["data"][:, :, 0] + return data.astype(np.float32) if cast else data + + +def _nyc_bike(channel): + with h5py.File("./data/NYCBike/NYC16x8.h5", "r") as f: + data = f["data"][:].astype(np.float32) + data = data.transpose(0, 2, 3, 1).reshape(-1, 16 * 8, 2) + return data[:, :, channel:channel + 1] + + def read_BeijingTaxi(): - files = ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", - "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"] - all_data = [] - for file in files: - data_path = os.path.join(f"./data/BeijingTaxi/{file}") - data = np.load(data_path) - all_data.append(data) - all_data = np.concatenate(all_data, axis=0) - time_num = all_data.shape[0] - all_data = all_data.transpose(0, 2, 3, 1).reshape(time_num, 32*32, 2) - return all_data \ No newline at end of file + files = [ + "TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", + "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy", + ] + data = np.concatenate( + [np.load(f"./data/BeijingTaxi/{f}") for f in files], axis=0 + ) + T = data.shape[0] + return data.transpose(0, 2, 3, 1).reshape(T, 32 * 32, 2) diff --git a/dataloader/loader_selector.py b/dataloader/loader_selector.py index 5ea47fa..caeeb03 100755 --- a/dataloader/loader_selector.py +++ b/dataloader/loader_selector.py @@ -8,21 +8,12 @@ from dataloader.Informer_loader import get_dataloader as Informer_loader def get_dataloader(config, normalizer, single): - TS_model = ["iTransformer", "HI", "PatchTST"] - model_name = config["basic"]["model"] - # if model_name == "Informer": - # return Informer_loader(config, normalizer, single) - # elif model_name in TS_model: - # return TS_loader(config, normalizer, single) - # else : - match model_name: - case "STGNCDE": - return cde_loader(config, normalizer, single) - case "STGNRDE": - return nrde_loader(config, normalizer, single) - case "DCRNN": - return DCRNN_loader(config, normalizer, single) - case "EXP": - return EXP_loader(config, normalizer, single) - case _: - return normal_loader(config, normalizer, single) + loader_map = { + "STGNCDE": cde_loader, + "STGNRDE": nrde_loader, + "DCRNN": DCRNN_loader, + "EXP": EXP_loader, + } + return loader_map.get(config["basic"]["model"], normal_loader)( + config, normalizer, single + ) diff --git a/train.py b/train.py index acd0e60..139cdfa 100644 --- a/train.py +++ b/train.py @@ -11,9 +11,9 @@ def read_config(config_path): config = yaml.safe_load(file) # 全局配置 - device = "cpu" # 指定设备为cuda:0 + device = "cuda:0" # 指定设备为cuda:0 seed = 2023 # 随机种子 - epochs = 120 + epochs = 1 # 拷贝项 config["basic"]["device"] = device @@ -65,8 +65,8 @@ def main(debug=False): model_list = ["iTransformer"] # 指定数据集 # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] - dataset_list = ["AirQuality"] - # dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"] + # dataset_list = ["AirQuality"] + dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"] # 我的调试开关,不做测试就填 str(False) # os.environ["TRY"] = str(False) diff --git a/trainer/TSTrainer.py b/trainer/TSTrainer.py index b8def31..c427072 100755 --- a/trainer/TSTrainer.py +++ b/trainer/TSTrainer.py @@ -1,296 +1,195 @@ 
diff --git a/train.py b/train.py
index acd0e60..139cdfa 100644
--- a/train.py
+++ b/train.py
@@ -11,9 +11,9 @@ def read_config(config_path):
         config = yaml.safe_load(file)
 
     # 全局配置
-    device = "cpu"  # 指定设备为cuda:0
+    device = "cuda:0"  # 指定设备为cuda:0
    seed = 2023  # 随机种子
-    epochs = 120
+    epochs = 1
 
     # 拷贝项
     config["basic"]["device"] = device
@@ -65,8 +65,8 @@ def main(debug=False):
     model_list = ["iTransformer"]
     # 指定数据集
     # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"]
-    dataset_list = ["AirQuality"]
-    # dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"]
+    # dataset_list = ["AirQuality"]
+    dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-Inflow", "NYCBike-Outflow"]
 
     # 我的调试开关,不做测试就填 str(False)
     # os.environ["TRY"] = str(False)
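The config change above pins training to cuda:0, which fails at the first
tensor placement on a machine without CUDA. A guarded variant (an assumption
on my part, not part of this patch):

    import torch
    # Hypothetical fallback; the patch itself hard-codes "cuda:0".
    device = "cuda:0" if torch.cuda.is_available() else "cpu"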
diff --git a/trainer/TSTrainer.py b/trainer/TSTrainer.py
index b8def31..c427072 100755
--- a/trainer/TSTrainer.py
+++ b/trainer/TSTrainer.py
@@ -1,296 +1,195 @@
-import math
-import os
-import time
-import copy
-import torch
+import os, time, copy, torch
+from tqdm import tqdm
 
 from utils.logger import get_logger
 from utils.loss_function import all_metrics
-from tqdm import tqdm
 
 
 class TSWrapper:
     def __init__(self, args):
-        self.b = args['train']['batch_size']
-        self.t = args['data']['lag']
         self.n = args['data']['num_nodes']
-        self.c = args['data']['input_dim']
-
-    def transpose(self, x : torch.Tensor):
+
+    def forward(self, x):
         # [b, t, n, c] -> [b*n, t, c]
-        self.b = x.shape[0]
-        x = x[..., :-2]
-        x = x.permute(0, 2, 1, 3)
-        x = x.reshape(self.b*self.n, self.t, self.c)
-        return x
-
-    def inv_transpose(self, x : torch.Tensor):
-        x = x.reshape(self.b, self.n, self.t, self.c)
-        x = x.permute(0, 2, 1, 3)
-        return x
+        b, t, n, c = x.shape
+        x = x[..., :-2].permute(0, 2, 1, 3).reshape(b * n, t, c-2)
+        return x, b, t, n, c
+
+    def inverse(self, x, b, t, n, c):
+        return x.reshape(b, n, t, c-2).permute(0, 2, 1, 3)
 
 
 class Trainer:
-    """模型训练器,负责整个训练流程的管理"""
-
     def __init__(self, model, loss, optimizer,
-                 train_loader, val_loader, test_loader, scaler,
-                 args, lr_scheduler=None,):
-        # 设备和基本参数
+                 train_loader, val_loader, test_loader,
+                 scaler, args, lr_scheduler=None):
+        self.config = args
         self.device = args["basic"]["device"]
-        train_args = args["train"]
-        # 模型和训练相关组件
-        self.model = model
+        self.args = args["train"]
+
+        self.model = model.to(self.device)
         self.loss = loss
         self.optimizer = optimizer
         self.lr_scheduler = lr_scheduler
-        # 数据加载器
+
         self.train_loader = train_loader
-        self.val_loader = val_loader
+        self.val_loader = val_loader or test_loader
         self.test_loader = test_loader
-        # 数据处理工具
         self.scaler = scaler
-        self.args = train_args
-        self.ts_wrapper = TSWrapper(args)
-        # 初始化路径、日志和统计
-        self._initialize_paths(train_args)
-        self._initialize_logger(train_args)
-
-    def _initialize_paths(self, args):
-        """初始化模型保存路径"""
-        self.best_path = os.path.join(args["log_dir"], "best_model.pth")
-        self.best_test_path = os.path.join(args["log_dir"], "best_test_model.pth")
-        self.loss_figure_path = os.path.join(args["log_dir"], "loss.png")
-
-    def _initialize_logger(self, args):
-        """初始化日志记录器"""
-        if not os.path.isdir(args["log_dir"]) and not args["debug"]:
-            os.makedirs(args["log_dir"], exist_ok=True)
-        self.logger = get_logger(args["log_dir"], name=self.model.__class__.__name__, debug=args["debug"])
-        self.logger.info(f"Experiment log path in: {args['log_dir']}")
+
+        self.ts = TSWrapper(args)
+        self._init_paths()
+        self._init_logger()
 
+    # ---------------- init ----------------
+    def _init_paths(self):
+        d = self.args["log_dir"]
+        self.best_path = os.path.join(d, "best_model.pth")
+        self.best_test_path = os.path.join(d, "best_test_model.pth")
+
+    def _init_logger(self):
+        if not self.args["debug"]:
+            os.makedirs(self.args["log_dir"], exist_ok=True)
+        self.logger = get_logger(
+            self.args["log_dir"],
+            name=self.model.__class__.__name__,
+            debug=self.args["debug"],
+        )
 
-    def _run_epoch(self, epoch, dataloader, mode):
-        """运行一个训练/验证/测试epoch"""
-        # 设置模型模式和是否进行优化
-        if mode == "train": self.model.train(); optimizer_step = True
-        else: self.model.eval(); optimizer_step = False
+    # ---------------- epoch ----------------
+    def _run_epoch(self, epoch, loader, mode):
+        is_train = mode == "train"
+        self.model.train() if is_train else self.model.eval()
 
-        # 初始化变量
-        total_loss = 0
-        epoch_time = time.time()
+        total_loss, start = 0.0, time.time()
         y_pred, y_true = [], []
 
-        # 训练/验证循环
-        with torch.set_grad_enabled(optimizer_step):
-            progress_bar = tqdm(
-                enumerate(dataloader),
-                total=len(dataloader),
-                desc=f"{mode.capitalize()} Epoch {epoch}"
-            )
-            for _, (data, target) in progress_bar:
-                # 转移数据
-                data = data.to(self.device)
-                target = target.to(self.device)
-                label = target[..., : self.args["output_dim"]]
-                # 转换为 [b*n, t, c]
-                data = self.ts_wrapper.transpose(data)
-                # 计算loss和反归一化loss
-                output = self.model(data)
-                # 转换回[b, t, n, c]
-                output = self.ts_wrapper.inv_transpose(output)
-                # 我的调试开关
+        with torch.set_grad_enabled(is_train):
+            for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)):
+                data, target = data.to(self.device), target.to(self.device)
+                label = target[..., :self.args["output_dim"]]
+
+                x, b, t, n, c = self.ts.forward(data)
+                out = self.model(x)
+                out = self.ts.inverse(out, b, t, n, c)
+
                 if os.environ.get("TRY") == "True":
-                    print(f"[{'✅' if output.shape == label.shape else '❌'}]: output: {output.shape}, label: {label.shape}")
-                    assert False
-                loss = self.loss(output, label)
-                d_output = self.scaler.inverse_transform(output)
-                d_label = self.scaler.inverse_transform(label)
-                d_loss = self.loss(d_output, d_label)
-                # 累积损失和预测结果
+                    print(out.shape, label.shape)
+                    assert out.shape == label.shape
+
+                loss = self.loss(out, label)
+                d_out = self.scaler.inverse_transform(out)
+                d_lbl = self.scaler.inverse_transform(label)
+                d_loss = self.loss(d_out, d_lbl)
+
                 total_loss += d_loss.item()
-                y_pred.append(d_output.detach().cpu())
-                y_true.append(d_label.detach().cpu())
-                # 反向传播和优化(仅在训练模式)
-                if optimizer_step and self.optimizer is not None:
+                y_pred.append(d_out.detach().cpu())
+                y_true.append(d_lbl.detach().cpu())
+
+                if is_train and self.optimizer:
                     self.optimizer.zero_grad()
                     loss.backward()
-                    # 梯度裁剪(如果需要)
                     if self.args["grad_norm"]:
-                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
+                        torch.nn.utils.clip_grad_norm_(
+                            self.model.parameters(),
+                            self.args["max_grad_norm"]
+                        )
                     self.optimizer.step()
-                # 更新进度条
-                progress_bar.set_postfix(loss=d_loss.item())
 
-        # 合并所有批次的预测结果
-        y_pred = torch.cat(y_pred, dim=0)
-        y_true = torch.cat(y_true, dim=0)
-        # 计算损失并记录指标
-        avg_loss = total_loss / len(dataloader)
-        mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
-        self.logger.info(
-            f"Epoch #{epoch:02d}: {mode.capitalize():<5} "
-            f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s"
+        y_pred = torch.cat(y_pred)
+        y_true = torch.cat(y_true)
+
+        mae, rmse, mape = all_metrics(
+            y_pred, y_true,
+            self.args["mae_thresh"],
+            self.args["mape_thresh"]
         )
-        return avg_loss
 
-    def train_epoch(self, epoch):
-        return self._run_epoch(epoch, self.train_loader, "train")
-
-    def val_epoch(self, epoch):
-        return self._run_epoch(epoch, self.val_loader or self.test_loader, "val")
-
-    def test_epoch(self, epoch):
-        return self._run_epoch(epoch, self.test_loader, "test")
+        self.logger.info(
+            f"Epoch #{epoch:02d} {mode:<5} "
+            f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} "
+            f"MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s"
+        )
+        return total_loss / len(loader)
 
+    # ---------------- train ----------------
     def train(self):
-        # 初始化记录
-        best_model, best_test_model = None, None
-        best_loss, best_test_loss = float("inf"), float("inf")
-        not_improved_count = 0
-        # 开始训练
-        self.logger.info("Training process started")
-        # 训练循环
+        best, best_test = float("inf"), float("inf")
+        best_w, best_test_w = None, None
+        patience = 0
+
+        self.logger.info("Training started")
+
         for epoch in range(1, self.args["epochs"] + 1):
-            # 训练、验证和测试一个epoch
-            train_epoch_loss = self.train_epoch(epoch)
-            val_epoch_loss = self.val_epoch(epoch)
-            test_epoch_loss = self.test_epoch(epoch)
-            # 检查梯度爆炸
-            if train_epoch_loss > 1e6:
-                self.logger.warning("Gradient explosion detected. Ending...")
+            losses = {
+                "train": self._run_epoch(epoch, self.train_loader, "train"),
+                "val": self._run_epoch(epoch, self.val_loader, "val"),
+                "test": self._run_epoch(epoch, self.test_loader, "test"),
+            }
+
+            if losses["train"] > 1e6:
+                self.logger.warning("Gradient explosion detected")
                 break
-            # 更新最佳验证模型
-            if val_epoch_loss < best_loss:
-                best_loss = val_epoch_loss
-                not_improved_count = 0
-                best_model = copy.deepcopy(self.model.state_dict())
-                self.logger.info("Best validation model saved!")
+
+            if losses["val"] < best:
+                best, patience = losses["val"], 0
+                best_w = copy.deepcopy(self.model.state_dict())
+                self.logger.info("Best validation model saved")
             else:
-                not_improved_count += 1
-                # 早停
-                if self._should_early_stop(not_improved_count):
+                patience += 1
+
+            if self.args["early_stop"] and patience == self.args["early_stop_patience"]:
+                self.logger.info("Early stopping triggered")
                 break
-            # 更新最佳测试模型
-            if test_epoch_loss < best_test_loss:
-                best_test_loss = test_epoch_loss
-                best_test_model = copy.deepcopy(self.model.state_dict())
-        # 保存最佳模型
+
+            if losses["test"] < best_test:
+                best_test = losses["test"]
+                best_test_w = copy.deepcopy(self.model.state_dict())
+
         if not self.args["debug"]:
-            self._save_best_models(best_model, best_test_model)
-        # 最终评估
-        self._finalize_training(best_model, best_test_model)
-
-    def _should_early_stop(self, not_improved_count):
-        """检查是否满足早停条件"""
-        if (
-            self.args["early_stop"]
-            and not_improved_count == self.args["early_stop_patience"]
-        ):
-            self.logger.info(
-                f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops."
-            )
-            return True
-        return False
-
-    def _save_best_models(self, best_model, best_test_model):
-        """保存最佳模型到文件"""
-        torch.save(best_model, self.best_path)
-        torch.save(best_test_model, self.best_test_path)
-        self.logger.info(
-            f"Best models saved at {self.best_path} and {self.best_test_path}"
-        )
-
-    def _log_model_params(self):
-        """输出模型可训练参数数量"""
-        total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad)
-        self.logger.info(f"Trainable params: {total_params}")
+            torch.save(best_w, self.best_path)
+            torch.save(best_test_w, self.best_test_path)
 
-
-    def _finalize_training(self, best_model, best_test_model):
-        self.model.load_state_dict(best_model)
-        self.logger.info("Testing on best validation model")
-        self.test(self.model, self.config, self.test_loader, self.scaler, self.logger)
-        self.model.load_state_dict(best_test_model)
-        self.logger.info("Testing on best test model")
-        self.test(self.model, self.config, self.test_loader, self.scaler, self.logger)
+        self._final_test(best_w, best_test_w)
 
-    @staticmethod
-    def test(model, args, data_loader, scaler, logger, path=None):
-        """对模型进行评估并输出性能指标"""
-        # 确定设备信息
-        device = None
-        output_dim = None
-        # 处理不同的参数格式
-        if isinstance(args, dict):
-            if "basic" in args:
-                # 完整配置情况
-                device = args["basic"]["device"]
-                output_dim = args["train"]["output_dim"]
-            else:
-                # 只有train_args情况,从模型获取设备
-                device = next(model.parameters()).device
-                output_dim = args["output_dim"]
-        else:
-            raise ValueError(f"Unsupported args type: {type(args)}")
-
-        # 加载模型检查点(如果提供了路径)
-        if path:
-            checkpoint = torch.load(path)
-            model.load_state_dict(checkpoint["state_dict"])
-            model.to(device)
+    # ---------------- final test ----------------
+    def _final_test(self, best_w, best_test_w):
+        for name, w in [("best val", best_w), ("best test", best_test_w)]:
+            self.model.load_state_dict(w)
+            self.logger.info(f"Testing on {name} model")
+            self.evaluate()
 
-        # 设置为评估模式
-        model.eval()
-
-        # 收集预测和真实标签
+    # ---------------- evaluate ----------------
+    def evaluate(self):
+        self.model.eval()
         y_pred, y_true = [], []
 
-        # 不计算梯度的情况下进行预测
         with torch.no_grad():
-            for data, target in data_loader:
-                # 将数据和标签移动到指定设备
-                data = data.to(device)
-                target = target.to(device)
-
-                data = data[..., :-2]
-                b, t, n, c = data.shape
-                data = data.permute(0, 2, 1, 3)
-                data = data.reshape(b*n, t, c)
-                label = target[..., : output_dim]
-                output = model(data)
-                output = output.reshape(b, n, t, c)
-                output = output.permute(0, 2, 1, 3)
+            for data, target in self.test_loader:
+                data, target = data.to(self.device), target.to(self.device)
+                label = target[..., :self.args["output_dim"]]
 
-                y_pred.append(output.detach().cpu())
-                y_true.append(label.detach().cpu())
+                x, b, t, n, c = self.ts.forward(data)
+                out = self.model(x)
+                out = self.ts.inverse(out, b, t, n, c)
 
-        d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
-        d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
+                y_pred.append(out.cpu())
+                y_true.append(label.cpu())
 
-        # 获取metrics参数
-        if "basic" in args:
-            # 完整配置情况
-            mae_thresh = args["train"]["mae_thresh"]
-            mape_thresh = args["train"]["mape_thresh"]
-        else:
-            # 只有train_args情况
-            mae_thresh = args["mae_thresh"]
-            mape_thresh = args["mape_thresh"]
-
-        # 计算并记录每个时间步的指标
-        for t in range(d_y_true.shape[1]):
+        d_pred = self.scaler.inverse_transform(torch.cat(y_pred))
+        d_true = self.scaler.inverse_transform(torch.cat(y_true))
+
+        for t in range(d_true.shape[1]):
             mae, rmse, mape = all_metrics(
-                d_y_pred[:, t, ...],
-                d_y_true[:, t, ...],
-                mae_thresh,
-                mape_thresh,
+                d_pred[:, t],
+                d_true[:, t],
+                self.args["mae_thresh"],
+                self.args["mape_thresh"]
             )
-            logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
+            self.logger.info(
+                f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}"
+            )
+
+        avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"])
+        self.logger.info(
+            f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}"
+        )
 
-        # 计算并记录平均指标
-        mae, rmse, mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh)
-        logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
-
-    @staticmethod
-    def _compute_sampling_threshold(global_step, k):
-        return k / (k + math.exp(global_step / k))
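Note that TSWrapper.forward drops the two trailing time-feature channels, so
the flattened tensor carries c-2 channels, not c. A quick round-trip check of
the reshape pair (illustrative sketch with made-up shapes, mirroring
forward()/inverse() above):

    import torch

    b, t, n, c = 4, 12, 7, 5               # c includes the 2 time features
    x = torch.randn(b, t, n, c)
    flat = x[..., :-2].permute(0, 2, 1, 3).reshape(b * n, t, c - 2)
    back = flat.reshape(b, n, t, c - 2).permute(0, 2, 1, 3)
    assert torch.equal(back, x[..., :-2])  # inverse restores [b, t, n, c-2]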
diff --git a/trainer/Trainer.py b/trainer/Trainer.py
index 04842ba..65980b9 100755
--- a/trainer/Trainer.py
+++ b/trainer/Trainer.py
@@ -1,4 +1,3 @@
-import math
 import os
 import time
 import copy
@@ -8,240 +7,100 @@ from utils.loss_function import all_metrics
 from tqdm import tqdm
 
 
 class Trainer:
-    """模型训练器,负责整个训练流程的管理"""
-
     def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None):
-        # 设备和基本参数
-        self.config = args
-        self.device = args["basic"]["device"]
-        self.args = args["train"]
-
-        # 模型和训练相关组件
+        self.config, self.device, self.args = args, args["basic"]["device"], args["train"]
         self.model, self.loss, self.optimizer, self.lr_scheduler = model, loss, optimizer, lr_scheduler
-
-        # 数据加载器
-        self.train_loader, self.val_loader, self.test_loader = train_loader, val_loader, test_loader
+        self.train_loader, self.val_loader, self.test_loader, self.scaler = train_loader, val_loader, test_loader, scaler
 
-        # 数据处理工具
-        self.scaler = scaler
+        log_dir = self.args["log_dir"]
+        self.best_path, self.best_test_path = [os.path.join(log_dir, f"best{suffix}_model.pth") for suffix in ["", "_test"]]  # -> best_model.pth, best_test_model.pth
 
-        # 初始化路径、日志和统计
-        self._initialize_paths(self.args)
-        self._initialize_logger(self.args)
-
-    def _initialize_paths(self, args):
-        """初始化模型保存路径"""
-        log_dir = args["log_dir"]
-        self.best_path = os.path.join(log_dir, "best_model.pth")
-        self.best_test_path = os.path.join(log_dir, "best_test_model.pth")
-        self.loss_figure_path = os.path.join(log_dir, "loss.png")
-
-    def _initialize_logger(self, args):
-        """初始化日志记录器"""
-        log_dir = args["log_dir"]
-        if not args["debug"]:
-            os.makedirs(log_dir, exist_ok=True)
-        self.logger = get_logger(log_dir, name=self.model.__class__.__name__, debug=args["debug"])
+        if not self.args["debug"]: os.makedirs(log_dir, exist_ok=True)
+        self.logger = get_logger(log_dir, name=self.model.__class__.__name__, debug=self.args["debug"])
         self.logger.info(f"Experiment log path in: {log_dir}")
 
-    def _run_epoch(self, epoch, dataloader, mode):
-        """运行一个训练/验证/测试epoch"""
-        # 设置模型模式和是否进行优化
-        self.model.train() if mode == "train" else self.model.eval()
-        optimizer_step = mode == "train"
-
-        # 初始化变量
-        total_loss = 0
-        epoch_time = time.time()
-        y_pred, y_true = [], []
-
-        # 训练/验证循环
-        with torch.set_grad_enabled(optimizer_step):
-            progress_bar = tqdm(
-                dataloader,
-                total=len(dataloader),
-                desc=f"{mode.capitalize()} Epoch {epoch}"
-            )
-            for data, target in progress_bar:
-                # 转移数据并提取标签
-                data, target = data.to(self.device), target.to(self.device)
-                label = target[..., : self.args["output_dim"]]
-
-                # 计算输出
-                output = self.model(data)
-
-                # 我的调试开关
-                if os.environ.get("TRY") == "True":
-                    status = '✅' if output.shape == label.shape else '❌'
-                    print(f"[{status}]: output: {output.shape}, label: {label.shape}")
-                    assert False
-
-                # 计算损失
-                loss = self.loss(output, label)
-                d_output = self.scaler.inverse_transform(output)
-                d_label = self.scaler.inverse_transform(label)
-                d_loss = self.loss(d_output, d_label)
-
-                # 累积损失和预测结果
-                total_loss += d_loss.item()
-                y_pred.append(d_output.detach().cpu())
-                y_true.append(d_label.detach().cpu())
-
-                # 反向传播和优化(仅在训练模式)
-                if optimizer_step and self.optimizer is not None:
-                    self.optimizer.zero_grad()
-                    loss.backward()
-
-                    # 梯度裁剪(如果需要)
-                    if self.args["grad_norm"]:
-                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
-
-                    self.optimizer.step()
-
-                # 更新进度条
-                progress_bar.set_postfix(loss=d_loss.item())
-
-        # 合并所有批次的预测结果
-        y_pred, y_true = torch.cat(y_pred, dim=0), torch.cat(y_true, dim=0)
-
-        # 计算损失并记录指标
-        avg_loss = total_loss / len(dataloader)
-        mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
-
-        self.logger.info(
-            f"Epoch #{epoch:02d}: {mode.capitalize():<5} "
-            f"MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s"
-        )
-
-        return avg_loss
-
     def train(self):
-        # 初始化记录
         best_model = best_test_model = None
         best_loss = best_test_loss = float("inf")
         not_improved_count = 0
 
-        # 开始训练
         self.logger.info("Training process started")
 
-        # 训练循环
         for epoch in range(1, self.args["epochs"] + 1):
-            # 训练、验证和测试一个epoch
-            train_epoch_loss = self._run_epoch(epoch, self.train_loader, "train")
-            val_epoch_loss = self._run_epoch(epoch, self.val_loader or self.test_loader, "val")
-            test_epoch_loss = self._run_epoch(epoch, self.test_loader, "test")
+            train_loss = self._run_epoch(epoch, self.train_loader, "train")
+            val_loss = self._run_epoch(epoch, self.val_loader or self.test_loader, "val")
+            test_loss = self._run_epoch(epoch, self.test_loader, "test")
 
-            # 检查梯度爆炸
-            if train_epoch_loss > 1e6:
+            if train_loss > 1e6:
                 self.logger.warning("Gradient explosion detected. Ending...")
                 break
 
-            # 更新最佳验证模型
-            if val_epoch_loss < best_loss:
-                best_loss, not_improved_count = val_epoch_loss, 0
-                best_model = copy.deepcopy(self.model.state_dict())
+            if val_loss < best_loss:
+                best_loss, not_improved_count, best_model = val_loss, 0, copy.deepcopy(self.model.state_dict())
                 self.logger.info("Best validation model saved!")
-            else:
-                not_improved_count += 1
-
-                # 早停检查
-                if self._should_early_stop(not_improved_count):
+            elif self.args["early_stop"] and (not_improved_count := not_improved_count + 1) == self.args["early_stop_patience"]:
+                self.logger.info(f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
                 break
 
-            # 更新最佳测试模型
-            if test_epoch_loss < best_test_loss:
-                best_test_loss = test_epoch_loss
-                best_test_model = copy.deepcopy(self.model.state_dict())
+            if test_loss < best_test_loss:
+                best_test_loss, best_test_model = test_loss, copy.deepcopy(self.model.state_dict())
 
-        # 保存最佳模型
-        if not self.args["debug"]:
-            self._save_best_models(best_model, best_test_model)
-
-        # 最终评估
-        self._finalize_training(best_model, best_test_model)
-
-    def _should_early_stop(self, not_improved_count):
-        """检查是否满足早停条件"""
-        if self.args["early_stop"] and not_improved_count == self.args["early_stop_patience"]:
-            self.logger.info(
-                f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops."
-            )
-            return True
-        return False
-
-    def _save_best_models(self, best_model, best_test_model):
-        """保存最佳模型到文件"""
+
         torch.save(best_model, self.best_path)
         torch.save(best_test_model, self.best_test_path)
-        self.logger.info(
-            f"Best models saved at {self.best_path} and {self.best_test_path}"
-        )
+        self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
+
+        for model_name, state_dict in [("best validation", best_model), ("best test", best_test_model)]:
+            self.model.load_state_dict(state_dict)
+            self.logger.info(f"Testing on {model_name} model")
+            self._run_epoch(None, self.test_loader, "test", log_horizon=True)
 
-    def _log_model_params(self):
-        """输出模型可训练参数数量"""
-        total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad)
-        self.logger.info(f"Trainable params: {total_params}")
+    def _run_epoch(self, epoch, dataloader, mode, log_horizon=False):
+        self.model.train() if mode == "train" else self.model.eval()
+        optimizer_step = mode == "train"
 
-
-    def _finalize_training(self, best_model, best_test_model):
-        self.model.load_state_dict(best_model)
-        self.logger.info("Testing on best validation model")
-        self.test(self.model, self.config, self.test_loader, self.scaler, self.logger)
-        self.model.load_state_dict(best_test_model)
-        self.logger.info("Testing on best test model")
-        self.test(self.model, self.config, self.test_loader, self.scaler, self.logger)
-
-    @staticmethod
-    def test(model, args, data_loader, scaler, logger, path=None):
-        """对模型进行评估并输出性能指标"""
-        # 验证参数类型
-        if not isinstance(args, dict):
-            raise ValueError(f"Unsupported args type: {type(args)}")
-
-        # 确定设备和输出维度
-        is_full_config = "basic" in args
-        device = args["basic"]["device"] if is_full_config else next(model.parameters()).device
-        output_dim = args["train"]["output_dim"] if is_full_config else args["output_dim"]
-
-        # 获取metrics参数
-        train_args = args["train"] if is_full_config else args
-        mae_thresh, mape_thresh = train_args["mae_thresh"], train_args["mape_thresh"]
-
-        # 加载模型检查点(如果提供了路径)
-        if path:
-            checkpoint = torch.load(path)
-            model.load_state_dict(checkpoint["state_dict"])
-            model.to(device)
-
-        # 设置为评估模式并收集预测结果
-        model.eval()
+        total_loss, epoch_time = 0, time.time()
         y_pred, y_true = [], []
-
-        # 不计算梯度的情况下进行预测
-        with torch.no_grad():
-            for data, target in data_loader:
-                # 将数据和标签移动到指定设备
-                data, target = data.to(device), target.to(device)
-                label = target[..., : output_dim]
-
-                output = model(data)
-                y_pred.append(output.detach().cpu())
-                y_true.append(label.detach().cpu())
-
-        # 反归一化并计算指标
-        d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
-        d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
 
-        # 计算并记录每个时间步的指标
-        for t in range(d_y_true.shape[1]):
-            mae, rmse, mape = all_metrics(
-                d_y_pred[:, t, ...],
-                d_y_true[:, t, ...],
-                mae_thresh,
-                mape_thresh,
-            )
-            logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
+        with torch.set_grad_enabled(optimizer_step):
+            for data, target in tqdm(dataloader, total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" if epoch else mode):
+                data, target = data.to(self.device), target.to(self.device)
+                label = target[..., :self.args["output_dim"]]
+
+                output = self.model(data)
+                loss = self.loss(output, label)
+                d_output, d_label = self.scaler.inverse_transform(output), self.scaler.inverse_transform(label)
+                d_loss = self.loss(d_output, d_label)
+
+                total_loss += d_loss.item()
+                y_pred.append(d_output.detach().cpu())
+                y_true.append(d_label.detach().cpu())
+
+                if optimizer_step and self.optimizer:
+                    self.optimizer.zero_grad()
+                    loss.backward()
+                    if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
+                    self.optimizer.step()
+
+        y_pred, y_true = torch.cat(y_pred, dim=0), torch.cat(y_true, dim=0)
+
+        if log_horizon:
+            for t in range(y_true.shape[1]):
+                mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], self.args["mae_thresh"], self.args["mape_thresh"])
+                self.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
+
+        avg_mae, avg_rmse, avg_mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
+
+        if epoch and mode:
+            self.logger.info(f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{avg_mae:5.2f} | RMSE:{avg_rmse:5.2f} | MAPE:{avg_mape:7.4f} | Time: {time.time()-epoch_time:.2f} s")
+        elif mode:
+            self.logger.info(f"{mode.capitalize():<5} MAE:{avg_mae:.4f} | RMSE:{avg_rmse:.4f} | MAPE:{avg_mape:.4f}")
+
+        return total_loss / len(dataloader)
 
-        # 计算并记录平均指标
-        avg_mae, avg_rmse, avg_mape = all_metrics(d_y_pred, d_y_true, mae_thresh, mape_thresh)
-        logger.info(f"Average Horizon, MAE: {avg_mae:.4f}, RMSE: {avg_rmse:.4f}, MAPE: {avg_mape:.4f}")
+    def test(self, path=None):
+        if path:
+            self.model.load_state_dict(torch.load(path)["state_dict"])
+            self.model.to(self.device)
+
+        self._run_epoch(None, self.test_loader, "test", log_horizon=True)
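The early-stop branch above leans on Python 3.8's walrus operator; its counter
semantics in isolation (runnable sketch). Note the inline form skips the
increment entirely when early_stop is disabled, which is harmless here because
the counter is read nowhere else:

    patience, count = 3, 0
    for improved in [False, False, False]:
        if improved:
            count = 0
        elif (count := count + 1) == patience:
            print("early stop")  # fires on the third non-improving epoch
            break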
diff --git a/trainer/trainer_selector.py b/trainer/trainer_selector.py
index 723b257..17aa81d 100755
--- a/trainer/trainer_selector.py
+++ b/trainer/trainer_selector.py
@@ -7,132 +7,31 @@ from trainer.E32Trainer import Trainer as EXP_Trainer
 from trainer.InformerTrainer import InformerTrainer
 from trainer.TSTrainer import Trainer as TSTrainer
 
+
 def select_trainer(
-    model,
-    loss,
-    optimizer,
-    train_loader,
-    val_loader,
-    test_loader,
-    scaler,
-    args,
-    lr_scheduler,
-    kwargs,
+    model, loss, optimizer,
+    train_loader, val_loader, test_loader,
+    scaler, args, lr_scheduler, kwargs
 ):
     model_name = args["basic"]["model"]
-    TS_model = ["HI", "PatchTST", "iTransformer"]
-    if model_name in TS_model:
-        return TSTrainer(
-            model,
-            loss,
-            optimizer,
-            train_loader,
-            val_loader,
-            test_loader,
-            scaler,
-            args,
-            lr_scheduler,
-        )
+    base_args = (
+        model, loss, optimizer,
+        train_loader, val_loader, test_loader,
+        scaler, args, lr_scheduler
+    )
+
+    if model_name in {"HI", "PatchTST", "iTransformer"}:
+        return TSTrainer(*base_args)
 
-    match model_name:
-        case "STGNCDE":
-            return cdeTrainer(
-                model,
-                loss,
-                optimizer,
-                train_loader,
-                val_loader,
-                test_loader,
-                scaler,
-                args,
-                lr_scheduler,
-                kwargs[0],
-                None,
-            )
-        case "STGNRDE":
-            return cdeTrainer(
-                model,
-                loss,
-                optimizer,
-                train_loader,
-                val_loader,
-                test_loader,
-                scaler,
-                args,
-                lr_scheduler,
-                kwargs[0],
-                None,
-            )
-        case "DCRNN":
-            return DCRNN_Trainer(
-                model,
-                loss,
-                optimizer,
-                train_loader,
-                val_loader,
-                test_loader,
-                scaler,
-                args,
-                lr_scheduler,
-            )
-        case "PDG2SEQ":
-            return PDG2SEQ_Trainer(
-                model,
-                loss,
-                optimizer,
-                train_loader,
-                val_loader,
-                test_loader,
-                scaler,
-                args,
-                lr_scheduler,
-            )
-        case "STMLP":
-            return STMLP_Trainer(
-                model,
-                loss,
-                optimizer,
-                train_loader,
-                val_loader,
-                test_loader,
-                scaler,
-                args,
-                lr_scheduler,
-            )
-        case "EXP":
-            return EXP_Trainer(
-                model,
-                loss,
-                optimizer,
-                train_loader,
-                val_loader,
-                test_loader,
-                scaler,
-                args,
-                lr_scheduler,
-            )
-        case "Informer":
-            return InformerTrainer(
-                model,
-                loss,
-                optimizer,
-                train_loader,
-                val_loader,
-                test_loader,
-                scaler,
-                args,
-                lr_scheduler,
-            )
-        case _:
-            return Trainer(
-                model,
-                loss,
-                optimizer,
-                train_loader,
-                val_loader,
-                test_loader,
-                scaler,
-                args,
-                lr_scheduler,
-            )
+    trainer_map = {
+        "DCRNN": DCRNN_Trainer,
+        "PDG2SEQ": PDG2SEQ_Trainer,
+        "STMLP": STMLP_Trainer,
+        "EXP": EXP_Trainer,
+        "Informer": InformerTrainer,
+    }
+
+    if model_name in {"STGNCDE", "STGNRDE"}:
+        return cdeTrainer(*base_args, kwargs[0], None)
+
+    return trainer_map.get(model_name, Trainer)(*base_args)
diff --git a/utils/logger.py b/utils/logger.py
index 8a2f187..7a818f6 100755
--- a/utils/logger.py
+++ b/utils/logger.py
@@ -18,7 +18,7 @@ def get_logger(root, name=None, debug=True):
     logger.handlers.clear()
 
     # 时间格式改为 年/月/日 时:分:秒
-    formatter = logging.Formatter("%(asctime)s - %(message)s", "%Y/%m/%d %H:%M:%S")
+    formatter = logging.Formatter("%(asctime)s - %(message)s", "%m/%d %H:%M")
 
     # 控制台输出
     console_handler = logging.StreamHandler()
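With the shortened format string, a log line renders as e.g.
"12/15 01:38 - message" instead of "2025/12/15 01:38:47 - message"
(illustrative sketch, standalone of the project's get_logger):

    import logging

    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s", "%m/%d %H:%M"))
    log = logging.getLogger("demo")
    log.addHandler(handler)
    log.warning("training started")  # -> "12/15 01:38 - training started"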