From 5e52f23c8d6c4f4c1275d79076c052e28efd99f3 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 15 Dec 2025 21:23:04 +0800 Subject: [PATCH] =?UTF-8?q?fix(config):=20=E4=BF=AE=E6=AD=A3=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E6=96=87=E4=BB=B6=E5=91=BD=E5=90=8D=E4=B8=8D=E4=B8=80?= =?UTF-8?q?=E8=87=B4=E9=97=AE=E9=A2=98=E5=B9=B6=E6=9B=B4=E6=96=B0=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit refactor(trainer): 重构训练器代码,优化反归一化处理和形状转换逻辑 style(trainer): 简化代码格式,提高可读性 chore: 更新训练脚本中的模型和数据集列表 --- ...{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} | 0 ...YCBike-Inflow.yaml => NYCBike-InFlow.yaml} | 0 ...Bike-Outflow.yaml => NYCBike-OutFlow.yaml} | 0 ...{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} | 0 ...JTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} | 0 ...YCBike-Inflow.yaml => NYCBike-InFlow.yaml} | 0 ...Bike-Outflow.yaml => NYCBike-OutFlow.yaml} | 0 ...{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} | 0 ...JTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} | 0 ...YCBike-Inflow.yaml => NYCBike-InFlow.yaml} | 0 ...Bike-Outflow.yaml => NYCBike-OutFlow.yaml} | 0 train.py | 10 +- trainer/TSTrainer.py | 99 +++++++------ trainer/Trainer.py | 135 ++++-------------- 14 files changed, 85 insertions(+), 159 deletions(-) rename config/HI/{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} (100%) rename config/HI/{NYCBike-Inflow.yaml => NYCBike-InFlow.yaml} (100%) rename config/HI/{NYCBike-Outflow.yaml => NYCBike-OutFlow.yaml} (100%) rename config/Informer/{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} (100%) rename config/Informer/{BJTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} (100%) rename config/Informer/{NYCBike-Inflow.yaml => NYCBike-InFlow.yaml} (100%) rename config/Informer/{NYCBike-Outflow.yaml => NYCBike-OutFlow.yaml} (100%) rename config/iTransformer/{BJTaxi-Inflow.yaml => BJTaxi-InFlow.yaml} (100%) rename config/iTransformer/{BJTaxi-Outflow.yaml => BJTaxi-OutFlow.yaml} (100%) rename config/iTransformer/{NYCBike-Inflow.yaml => NYCBike-InFlow.yaml} (100%) rename config/iTransformer/{NYCBike-Outflow.yaml => NYCBike-OutFlow.yaml} (100%) diff --git a/config/HI/BJTaxi-Inflow.yaml b/config/HI/BJTaxi-InFlow.yaml similarity index 100% rename from config/HI/BJTaxi-Inflow.yaml rename to config/HI/BJTaxi-InFlow.yaml diff --git a/config/HI/NYCBike-Inflow.yaml b/config/HI/NYCBike-InFlow.yaml similarity index 100% rename from config/HI/NYCBike-Inflow.yaml rename to config/HI/NYCBike-InFlow.yaml diff --git a/config/HI/NYCBike-Outflow.yaml b/config/HI/NYCBike-OutFlow.yaml similarity index 100% rename from config/HI/NYCBike-Outflow.yaml rename to config/HI/NYCBike-OutFlow.yaml diff --git a/config/Informer/BJTaxi-Inflow.yaml b/config/Informer/BJTaxi-InFlow.yaml similarity index 100% rename from config/Informer/BJTaxi-Inflow.yaml rename to config/Informer/BJTaxi-InFlow.yaml diff --git a/config/Informer/BJTaxi-Outflow.yaml b/config/Informer/BJTaxi-OutFlow.yaml similarity index 100% rename from config/Informer/BJTaxi-Outflow.yaml rename to config/Informer/BJTaxi-OutFlow.yaml diff --git a/config/Informer/NYCBike-Inflow.yaml b/config/Informer/NYCBike-InFlow.yaml similarity index 100% rename from config/Informer/NYCBike-Inflow.yaml rename to config/Informer/NYCBike-InFlow.yaml diff --git a/config/Informer/NYCBike-Outflow.yaml b/config/Informer/NYCBike-OutFlow.yaml similarity index 100% rename from config/Informer/NYCBike-Outflow.yaml rename to config/Informer/NYCBike-OutFlow.yaml diff --git a/config/iTransformer/BJTaxi-Inflow.yaml b/config/iTransformer/BJTaxi-InFlow.yaml similarity index 100% rename from config/iTransformer/BJTaxi-Inflow.yaml rename to config/iTransformer/BJTaxi-InFlow.yaml diff --git a/config/iTransformer/BJTaxi-Outflow.yaml b/config/iTransformer/BJTaxi-OutFlow.yaml similarity index 100% rename from config/iTransformer/BJTaxi-Outflow.yaml rename to config/iTransformer/BJTaxi-OutFlow.yaml diff --git a/config/iTransformer/NYCBike-Inflow.yaml b/config/iTransformer/NYCBike-InFlow.yaml similarity index 100% rename from config/iTransformer/NYCBike-Inflow.yaml rename to config/iTransformer/NYCBike-InFlow.yaml diff --git a/config/iTransformer/NYCBike-Outflow.yaml b/config/iTransformer/NYCBike-OutFlow.yaml similarity index 100% rename from config/iTransformer/NYCBike-Outflow.yaml rename to config/iTransformer/NYCBike-OutFlow.yaml diff --git a/train.py b/train.py index 83d056a..b0b5af1 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,7 @@ def read_config(config_path): # 全局配置 device = "cuda:0" # 指定设备为cuda:0 seed = 2023 # 随机种子 - epochs = 1 # 训练轮数 + epochs = 100 # 训练轮数 # 拷贝项 config["basic"]["device"] = device @@ -91,8 +91,8 @@ if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] # model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"] - model_list = ["MTGNN"] - # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] - dataset_list = ["AirQuality"] + model_list = ["iTransformer"] + dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] + # dataset_list = ["AirQuality"] # dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] - main(model_list, dataset_list, debug = False) \ No newline at end of file + main(model_list, dataset_list, debug = True) \ No newline at end of file diff --git a/trainer/TSTrainer.py b/trainer/TSTrainer.py index 5ba71f2..81ee54f 100755 --- a/trainer/TSTrainer.py +++ b/trainer/TSTrainer.py @@ -3,25 +3,13 @@ from tqdm import tqdm from utils.logger import get_logger from utils.loss_function import all_metrics -class TSWrapper: - def __init__(self, args): - self.n = args['data']['num_nodes'] - - def forward(self, x): - # [b, t, n, c] -> [b*n, t, c] - b, t, n, c = x.shape - x = x[..., :-2].permute(0, 2, 1, 3).reshape(b * n, t, c-2) - return x, b, t, n, c - - def inverse(self, x, b, t, n, c): - return x.reshape(b, n, t, c-2).permute(0, 2, 1, 3) - class Trainer: def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None): + self.config = args self.device = args["basic"]["device"] self.args = args["train"] @@ -35,7 +23,24 @@ class Trainer: self.test_loader = test_loader self.scaler = scaler - self.ts = TSWrapper(args) + # ---------- shape magic (replace TSWrapper) ---------- + self.pack = lambda x: ( + x[..., :-2] + .permute(0, 2, 1, 3) + .reshape(-1, x.size(1), x.size(3) - 2), + x.shape + ) + self.unpack = lambda y, s: ( + y.reshape(s[0], s[2], s[1], -1) + .permute(0, 2, 1, 3) + ) + + # ---------- inverse scaler ---------- + self.inv = lambda x: torch.cat( + [s.inverse_transform(x[..., i:i+1]) for i, s in enumerate(self.scaler)], + dim=-1 + ) + self._init_paths() self._init_logger() @@ -51,7 +56,7 @@ class Trainer: self.logger = get_logger( self.args["log_dir"], name=self.model.__class__.__name__, - debug=self.args["debug"], + debug=self.args["debug"] ) # ---------------- epoch ---------------- @@ -67,21 +72,17 @@ class Trainer: data, target = data.to(self.device), target.to(self.device) label = target[..., :self.args["output_dim"]] - x, b, t, n, c = self.ts.forward(data) - out = self.model(x) - out = self.ts.inverse(out, b, t, n, c) + x, shp = self.pack(data) + out = self.unpack(self.model(x), shp) if os.environ.get("TRY") == "True": - if out.shape == label.shape: - print("shape true") - assert False - else: - print("shape false") - assert False + print(f"out:{out.shape} label:{label.shape}", + "✅" if out.shape == label.shape else "❌") + assert False loss = self.loss(out, label) - d_out = self.scaler.inverse_transform(out) - d_lbl = self.scaler.inverse_transform(label) + + d_out, d_lbl = self.inv(out), self.inv(label) d_loss = self.loss(d_out, d_lbl) total_loss += d_loss.item() @@ -98,9 +99,7 @@ class Trainer: ) self.optimizer.step() - y_pred = torch.cat(y_pred) - y_true = torch.cat(y_true) - + y_pred, y_true = torch.cat(y_pred), torch.cat(y_true) mae, rmse, mape = all_metrics( y_pred, y_true, self.args["mae_thresh"], @@ -110,23 +109,28 @@ class Trainer: self.logger.info( f"Epoch #{epoch:02d} {mode:<5} " f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} " - f"MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s" + f"MAPE:{mape:7.4f} " + f"Time:{time.time() - start:.2f}s" ) + return total_loss / len(loader) # ---------------- train ---------------- def train(self): - best, best_test = float("inf"), float("inf") - best_w, best_test_w = None, None + best = best_test = float("inf") + best_w = best_test_w = None patience = 0 self.logger.info("Training started") for epoch in range(1, self.args["epochs"] + 1): losses = { - "train": self._run_epoch(epoch, self.train_loader, "train"), - "val": self._run_epoch(epoch, self.val_loader, "val"), - "test": self._run_epoch(epoch, self.test_loader, "test"), + k: self._run_epoch(epoch, l, k) + for k, l in [ + ("train", self.train_loader), + ("val", self.val_loader), + ("test", self.test_loader) + ] } if losses["train"] > 1e6: @@ -171,15 +175,14 @@ class Trainer: data, target = data.to(self.device), target.to(self.device) label = target[..., :self.args["output_dim"]] - x, b, t, n, c = self.ts.forward(data) - out = self.model(x) - out = self.ts.inverse(out, b, t, n, c) + x, shp = self.pack(data) + out = self.unpack(self.model(x), shp) y_pred.append(out.cpu()) y_true.append(label.cpu()) - d_pred = self.scaler.inverse_transform(torch.cat(y_pred)) - d_true = self.scaler.inverse_transform(torch.cat(y_true)) + d_pred = self.inv(torch.cat(y_pred)) + d_true = self.inv(torch.cat(y_true)) for t in range(d_true.shape[1]): mae, rmse, mape = all_metrics( @@ -188,11 +191,15 @@ class Trainer: self.args["mape_thresh"] ) self.logger.info( - f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}" + f"Horizon {t+1:02d} " + f"MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}" ) - - avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"]) - self.logger.info( - f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}" - ) + mae, rmse, mape = all_metrics( + d_pred, d_true, + self.args["mae_thresh"], + self.args["mape_thresh"] + ) + self.logger.info( + f"AVG MAE:{mae:.4f} AVG RMSE:{rmse:.4f} AVG MAPE:{mape:.4f}" + ) diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 7c9aee0..e0838b5 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -3,59 +3,29 @@ from tqdm import tqdm from utils.logger import get_logger from utils.loss_function import all_metrics - class Trainer: - def __init__(self, model, loss, optimizer, - train_loader, val_loader, test_loader, - scaler, args, lr_scheduler=None): - - self.config = args - self.device = args["basic"]["device"] - self.args = args["train"] - - self.model = model.to(self.device) - self.loss = loss - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - - self.train_loader = train_loader - self.val_loader = val_loader or test_loader - self.test_loader = test_loader + def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None): + self.device, self.args = args["basic"]["device"], args["train"] + self.model, self.loss, self.optimizer, self.lr_scheduler = model.to(self.device), loss, optimizer, lr_scheduler + self.train_loader, self.val_loader, self.test_loader = train_loader, val_loader or test_loader, test_loader self.scaler = scaler - - # ===== 新增:统一反归一化接口(单 scaler / 多 scaler 通吃)===== - self.inv = ( - (lambda x: self.scaler.inverse_transform(x)) - if not isinstance(self.scaler, (list, tuple)) - else (lambda x: torch.cat( - [s.inverse_transform(x[..., i:i+1]) - for i, s in enumerate(self.scaler)], - dim=-1)) - ) - + self.inv = lambda x: torch.cat([s.inverse_transform(x[..., i:i+1]) for i, s in enumerate(self.scaler)], dim=-1) # 对每个维度调用反归一化器后cat self._init_paths() self._init_logger() # ---------------- init ---------------- def _init_paths(self): d = self.args["log_dir"] - self.best_path = os.path.join(d, "best_model.pth") - self.best_test_path = os.path.join(d, "best_test_model.pth") + self.best_path, self.best_test_path = os.path.join(d, "best_model.pth"), os.path.join(d, "best_test_model.pth") def _init_logger(self): - if not self.args["debug"]: - os.makedirs(self.args["log_dir"], exist_ok=True) - self.logger = get_logger( - self.args["log_dir"], - name=self.model.__class__.__name__, - debug=self.args["debug"], - ) + if not self.args["debug"]: os.makedirs(self.args["log_dir"], exist_ok=True) + self.logger = get_logger(self.args["log_dir"], name=self.model.__class__.__name__, debug=self.args["debug"]) # ---------------- epoch ---------------- def _run_epoch(self, epoch, loader, mode): is_train = mode == "train" self.model.train() if is_train else self.model.eval() - total_loss, start = 0.0, time.time() y_pred, y_true = [], [] @@ -63,20 +33,12 @@ class Trainer: for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)): data, target = data.to(self.device), target.to(self.device) label = target[..., :self.args["output_dim"]] - out = self.model(data) - - if os.environ.get("TRY") == "True": - print(f"out: {out.shape}, label: {label.shape}") - assert False - + if os.environ.get("TRY") == "True": print(f"out: {out.shape}, label: {label.shape} \ + {'✅' if out.shape == label.shape else '❌'}"); assert False loss = self.loss(out, label) - - # ===== 修改点:反归一化 ===== - d_out = self.inv(out) - d_lbl = self.inv(label) + d_out, d_lbl = self.inv(out), self.inv(label) # 反归一化 d_loss = self.loss(d_out, d_lbl) - total_loss += d_loss.item() y_pred.append(d_out.detach().cpu()) y_true.append(d_lbl.detach().cpu()) @@ -84,27 +46,12 @@ class Trainer: if is_train and self.optimizer: self.optimizer.zero_grad() loss.backward() - if self.args["grad_norm"]: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), - self.args["max_grad_norm"] - ) + if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"]) self.optimizer.step() - y_pred = torch.cat(y_pred) - y_true = torch.cat(y_true) - - mae, rmse, mape = all_metrics( - y_pred, y_true, - self.args["mae_thresh"], - self.args["mape_thresh"] - ) - - self.logger.info( - f"Epoch #{epoch:02d} {mode:<5} " - f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} " - f"MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s" - ) + y_pred, y_true = torch.cat(y_pred), torch.cat(y_true) + mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info(f"Epoch #{epoch:02d} {mode:<5} MAE:{mae:5.2f} RMSE:{rmse:5.2f} MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s") return total_loss / len(loader) # ---------------- train ---------------- @@ -112,37 +59,24 @@ class Trainer: best, best_test = float("inf"), float("inf") best_w, best_test_w = None, None patience = 0 - self.logger.info("Training started") for epoch in range(1, self.args["epochs"] + 1): losses = { "train": self._run_epoch(epoch, self.train_loader, "train"), - "val": self._run_epoch(epoch, self.val_loader, "val"), - "test": self._run_epoch(epoch, self.test_loader, "test"), + "val": self._run_epoch(epoch, self.val_loader, "val"), + "test": self._run_epoch(epoch, self.test_loader, "test"), } - if losses["train"] > 1e6: - self.logger.warning("Gradient explosion detected") - break - - if losses["val"] < best: - best, patience = losses["val"], 0 - best_w = copy.deepcopy(self.model.state_dict()) - else: - patience += 1 - - if self.args["early_stop"] and patience == self.args["early_stop_patience"]: - break - - if losses["test"] < best_test: - best_test = losses["test"] - best_test_w = copy.deepcopy(self.model.state_dict()) + if losses["train"] > 1e6: self.logger.warning("Gradient explosion detected"); break + if losses["val"] < best: best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict()) + else: patience += 1 + if self.args["early_stop"] and patience == self.args["early_stop_patience"]: break + if losses["test"] < best_test: best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict()) if not self.args["debug"]: torch.save(best_w, self.best_path) torch.save(best_test_w, self.best_test_path) - self._final_test(best_w, best_test_w) # ---------------- final test ---------------- @@ -164,25 +98,10 @@ class Trainer: y_pred.append(self.model(data).cpu()) y_true.append(label.cpu()) - # ===== 修改点:反归一化 ===== - d_pred = self.inv(torch.cat(y_pred)) - d_true = self.inv(torch.cat(y_true)) - + d_pred, d_true = self.inv(torch.cat(y_pred)), self.inv(torch.cat(y_true)) # 反归一化 for t in range(d_true.shape[1]): - mae, rmse, mape = all_metrics( - d_pred[:, t], d_true[:, t], - self.args["mae_thresh"], - self.args["mape_thresh"] - ) - self.logger.info( - f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}" - ) + mae, rmse, mape = all_metrics(d_pred[:, t], d_true[:, t], self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info(f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}") - avg_mae, avg_rmse, avg_mape = all_metrics( - d_pred, d_true, - self.args["mae_thresh"], - self.args["mape_thresh"] - ) - self.logger.info( - f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}" - ) + avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"]) + self.logger.info(f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}")