From 85257bc61ca08edf619dde5c20e32e13bdea2939 Mon Sep 17 00:00:00 2001
From: czzhangheng
Date: Tue, 16 Dec 2025 16:40:40 +0800
Subject: [PATCH] refactor(trainer): optimize trainer code structure and add
 progress bar display
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adjust the trainer code structure, reduce redundant code, and improve readability
Add a tqdm progress bar to the training loop to show the loss in real time
Unify the output format of the TRY environment variable
Simplify logging and model-saving logic
---
 config/ASTRA_v2/SolarEnergy.yaml |   4 +-
 config/GWN/METR-LA.yaml          |   4 +-
 config/GWN/SolarEnergy.yaml      |   4 +-
 config/MTGNN/SolarEnergy.yaml    |   2 +-
 config/REPST/SolarEnergy.yaml    |   4 +-
 train.py                         |   4 +-
 trainer/TSTrainer.py             | 163 +++++++------------------------
 trainer/Trainer.py               |   8 +-
 8 files changed, 52 insertions(+), 141 deletions(-)

diff --git a/config/ASTRA_v2/SolarEnergy.yaml b/config/ASTRA_v2/SolarEnergy.yaml
index 83a87c2..9b6a223 100644
--- a/config/ASTRA_v2/SolarEnergy.yaml
+++ b/config/ASTRA_v2/SolarEnergy.yaml
@@ -6,7 +6,7 @@ basic:
   seed: 2023

 data:
-  batch_size: 64
+  batch_size: 16
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -34,7 +34,7 @@ model:
   word_num: 1000

 train:
-  batch_size: 64
+  batch_size: 16
   debug: false
   early_stop: true
   early_stop_patience: 15
diff --git a/config/GWN/METR-LA.yaml b/config/GWN/METR-LA.yaml
index ef38574..fc93634 100644
--- a/config/GWN/METR-LA.yaml
+++ b/config/GWN/METR-LA.yaml
@@ -6,7 +6,7 @@ basic:
   seed: 2023

 data:
-  batch_size: 64
+  batch_size: 16
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -40,7 +40,7 @@ model:
   supports: null

 train:
-  batch_size: 64
+  batch_size: 16
   debug: false
   early_stop: true
   early_stop_patience: 15
diff --git a/config/GWN/SolarEnergy.yaml b/config/GWN/SolarEnergy.yaml
index cd1d043..4e572fa 100644
--- a/config/GWN/SolarEnergy.yaml
+++ b/config/GWN/SolarEnergy.yaml
@@ -6,7 +6,7 @@ basic:
   seed: 2023

 data:
-  batch_size: 64
+  batch_size: 16
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -40,7 +40,7 @@ model:
   supports: null

 train:
-  batch_size: 64
+  batch_size: 16
   debug: false
   early_stop: true
   early_stop_patience: 15
diff --git a/config/MTGNN/SolarEnergy.yaml b/config/MTGNN/SolarEnergy.yaml
index 2f60b8d..57e17c8 100644
--- a/config/MTGNN/SolarEnergy.yaml
+++ b/config/MTGNN/SolarEnergy.yaml
@@ -10,7 +10,7 @@ data:
   column_wise: false
   days_per_week: 7
   horizon: 24
-  input_dim: 6
+  input_dim: 1
   lag: 24
   normalizer: std
   num_nodes: 137
diff --git a/config/REPST/SolarEnergy.yaml b/config/REPST/SolarEnergy.yaml
index dd4579e..a96e58a 100755
--- a/config/REPST/SolarEnergy.yaml
+++ b/config/REPST/SolarEnergy.yaml
@@ -6,7 +6,7 @@ basic:
   seed: 2023

 data:
-  batch_size: 64
+  batch_size: 16
   column_wise: false
   days_per_week: 7
   horizon: 24
@@ -34,7 +34,7 @@ model:
   word_num: 1000

 train:
-  batch_size: 64
+  batch_size: 16
   debug: false
   early_stop: true
   early_stop_patience: 15
diff --git a/train.py b/train.py
index fcaaa6a..b5b42b5 100644
--- a/train.py
+++ b/train.py
@@ -13,7 +13,7 @@ def read_config(config_path):
     # global configuration
     device = "cuda:0"  # use device cuda:0
     seed = 2023  # random seed
-    epochs = 100  # number of training epochs
+    epochs = 1  # number of training epochs

     # copy these into the config
     config["basic"]["device"] = device
@@ -91,7 +91,7 @@ if __name__ == "__main__":
     # for debugging
     model_list = ["iTransformer", "PatchTST", "HI"]
     # model_list = ["ASTRA_v2", "GWN", "REPST", "STAEFormer", "MTGNN"]
-    # model_list = ["iTransformer"]
+    # model_list = ["MTGNN"]
     # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"]
"NYCBike-OutFlow"] # dataset_list = ["AirQuality"] dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] diff --git a/trainer/TSTrainer.py b/trainer/TSTrainer.py index 81ee54f..3ddf361 100755 --- a/trainer/TSTrainer.py +++ b/trainer/TSTrainer.py @@ -5,86 +5,46 @@ from utils.loss_function import all_metrics class Trainer: - def __init__(self, model, loss, optimizer, - train_loader, val_loader, test_loader, - scaler, args, lr_scheduler=None): - - self.config = args - self.device = args["basic"]["device"] - self.args = args["train"] - - self.model = model.to(self.device) - self.loss = loss - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - - self.train_loader = train_loader - self.val_loader = val_loader or test_loader - self.test_loader = test_loader + def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None): + self.device, self.args = args["basic"]["device"], args["train"] + self.model, self.loss, self.optimizer, self.lr_scheduler = model.to(self.device), loss, optimizer, lr_scheduler + self.train_loader, self.val_loader, self.test_loader = train_loader, val_loader or test_loader, test_loader self.scaler = scaler - - # ---------- shape magic (replace TSWrapper) ---------- - self.pack = lambda x: ( - x[..., :-2] - .permute(0, 2, 1, 3) - .reshape(-1, x.size(1), x.size(3) - 2), - x.shape - ) - self.unpack = lambda y, s: ( - y.reshape(s[0], s[2], s[1], -1) - .permute(0, 2, 1, 3) - ) - - # ---------- inverse scaler ---------- - self.inv = lambda x: torch.cat( - [s.inverse_transform(x[..., i:i+1]) for i, s in enumerate(self.scaler)], - dim=-1 - ) - + self.inv = lambda x: torch.cat([s.inverse_transform(x[..., i:i+1]) for i, s in enumerate(self.scaler)], dim=-1) # 对每个维度调用反归一化器后cat self._init_paths() self._init_logger() + # ---------- shape magic (replace TSWrapper) ---------- + self.pack = lambda x:(x[..., :-2].permute(0, 2, 1, 3).reshape(-1, x.size(1), x.size(3) - 2), x.shape) + self.unpack = lambda y, s: (y.reshape(s[0], s[2], s[1], -1).permute(0, 2, 1, 3)) # ---------------- init ---------------- def _init_paths(self): d = self.args["log_dir"] - self.best_path = os.path.join(d, "best_model.pth") - self.best_test_path = os.path.join(d, "best_test_model.pth") + self.best_path, self.best_test_path = os.path.join(d, "best_model.pth"), os.path.join(d, "best_test_model.pth") def _init_logger(self): - if not self.args["debug"]: - os.makedirs(self.args["log_dir"], exist_ok=True) - self.logger = get_logger( - self.args["log_dir"], - name=self.model.__class__.__name__, - debug=self.args["debug"] - ) + if not self.args["debug"]: os.makedirs(self.args["log_dir"], exist_ok=True) + self.logger = get_logger(self.args["log_dir"], name=self.model.__class__.__name__, debug=self.args["debug"]) # ---------------- epoch ---------------- def _run_epoch(self, epoch, loader, mode): is_train = mode == "train" self.model.train() if is_train else self.model.eval() - total_loss, start = 0.0, time.time() y_pred, y_true = [], [] with torch.set_grad_enabled(is_train): - for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)): + bar = tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)) + for data, target in bar: data, target = data.to(self.device), target.to(self.device) label = target[..., :self.args["output_dim"]] - x, shp = self.pack(data) out = self.unpack(self.model(x), shp) - - if os.environ.get("TRY") == "True": - print(f"out:{out.shape} label:{label.shape}", - "✅" if out.shape == 
-                    assert False
-
+                if os.environ.get("TRY") == "True": print(f"{'[✅]' if out.shape == label.shape else '❌'} "
+                                                          f"out: {out.shape}, label: {label.shape} \n"); assert False
                 loss = self.loss(out, label)
-
-                d_out, d_lbl = self.inv(out), self.inv(label)
+                d_out, d_lbl = self.inv(out), self.inv(label)  # de-normalize
                 d_loss = self.loss(d_out, d_lbl)
-
                 total_loss += d_loss.item()
                 y_pred.append(d_out.detach().cpu())
                 y_true.append(d_lbl.detach().cpu())
@@ -92,70 +52,38 @@ class Trainer:
             if is_train and self.optimizer:
                 self.optimizer.zero_grad()
                 loss.backward()
-                if self.args["grad_norm"]:
-                    torch.nn.utils.clip_grad_norm_(
-                        self.model.parameters(),
-                        self.args["max_grad_norm"]
-                    )
+                if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
                 self.optimizer.step()
+                bar.set_postfix({"loss": f"{d_loss.item():.4f}"})

         y_pred, y_true = torch.cat(y_pred), torch.cat(y_true)
-        mae, rmse, mape = all_metrics(
-            y_pred, y_true,
-            self.args["mae_thresh"],
-            self.args["mape_thresh"]
-        )
-
-        self.logger.info(
-            f"Epoch #{epoch:02d} {mode:<5} "
-            f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} "
-            f"MAPE:{mape:7.4f} "
-            f"Time:{time.time() - start:.2f}s"
-        )
-
+        mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
+        self.logger.info(f"Epoch #{epoch:02d} {mode:<5} MAE:{mae:5.2f} RMSE:{rmse:5.2f} MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s")

         return total_loss / len(loader)

     # ---------------- train ----------------
     def train(self):
-        best = best_test = float("inf")
-        best_w = best_test_w = None
+        best, best_test = float("inf"), float("inf")
+        best_w, best_test_w = None, None
         patience = 0
-
         self.logger.info("Training started")
         for epoch in range(1, self.args["epochs"] + 1):
             losses = {
-                k: self._run_epoch(epoch, l, k)
-                for k, l in [
-                    ("train", self.train_loader),
-                    ("val", self.val_loader),
-                    ("test", self.test_loader)
-                ]
+                "train": self._run_epoch(epoch, self.train_loader, "train"),
+                "val": self._run_epoch(epoch, self.val_loader, "val"),
+                "test": self._run_epoch(epoch, self.test_loader, "test"),
             }

-            if losses["train"] > 1e6:
-                self.logger.warning("Gradient explosion detected")
-                break
-
-            if losses["val"] < best:
-                best, patience = losses["val"], 0
-                best_w = copy.deepcopy(self.model.state_dict())
-                self.logger.info("Best validation model saved")
-            else:
-                patience += 1
-
-            if self.args["early_stop"] and patience == self.args["early_stop_patience"]:
-                self.logger.info("Early stopping triggered")
-                break
-
-            if losses["test"] < best_test:
-                best_test = losses["test"]
-                best_test_w = copy.deepcopy(self.model.state_dict())
+            if losses["train"] > 1e6: self.logger.warning("Gradient explosion detected"); break
+            if losses["val"] < best: best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict())
+            else: patience += 1
+            if self.args["early_stop"] and patience == self.args["early_stop_patience"]: break
+            if losses["test"] < best_test: best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict())

         if not self.args["debug"]:
             torch.save(best_w, self.best_path)
             torch.save(best_test_w, self.best_test_path)
-
         self._final_test(best_w, best_test_w)

     # ---------------- final test ----------------
@@ -174,32 +102,13 @@ class Trainer:
         for data, target in self.test_loader:
             data, target = data.to(self.device), target.to(self.device)
             label = target[..., :self.args["output_dim"]]
-
-            x, shp = self.pack(data)
-            out = self.unpack(self.model(x), shp)
-
-            y_pred.append(out.cpu())
+            y_pred.append(self.model(data).cpu())
             y_true.append(label.cpu())

-        d_pred = self.inv(torch.cat(y_pred))
-        d_true = self.inv(torch.cat(y_true))
-
+        d_pred, d_true = self.inv(torch.cat(y_pred)), self.inv(torch.cat(y_true))  # de-normalize
         for t in range(d_true.shape[1]):
-            mae, rmse, mape = all_metrics(
-                d_pred[:, t], d_true[:, t],
-                self.args["mae_thresh"],
-                self.args["mape_thresh"]
-            )
-            self.logger.info(
-                f"Horizon {t+1:02d} "
-                f"MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}"
-            )
+            mae, rmse, mape = all_metrics(d_pred[:, t], d_true[:, t], self.args["mae_thresh"], self.args["mape_thresh"])
+            self.logger.info(f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}")

-        mae, rmse, mape = all_metrics(
-            d_pred, d_true,
-            self.args["mae_thresh"],
-            self.args["mape_thresh"]
-        )
-        self.logger.info(
-            f"AVG MAE:{mae:.4f} AVG RMSE:{rmse:.4f} AVG MAPE:{mape:.4f}"
-        )
+        avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"])
+        self.logger.info(f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}")
diff --git a/trainer/Trainer.py b/trainer/Trainer.py
index e0838b5..7036e26 100755
--- a/trainer/Trainer.py
+++ b/trainer/Trainer.py
@@ -30,12 +30,13 @@ class Trainer:
         y_pred, y_true = [], []

         with torch.set_grad_enabled(is_train):
-            for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)):
+            bar = tqdm(loader, desc=f"{mode} {epoch}", total=len(loader))
+            for data, target in bar:
                 data, target = data.to(self.device), target.to(self.device)
                 label = target[..., :self.args["output_dim"]]
                 out = self.model(data)
-                if os.environ.get("TRY") == "True": print(f"out: {out.shape}, label: {label.shape} \
-                    {'✅' if out.shape == label.shape else '❌'}"); assert False
+                if os.environ.get("TRY") == "True": print(f"{'[✅]' if out.shape == label.shape else '❌'} "
+                                                          f"out: {out.shape}, label: {label.shape} \n"); assert False
                 loss = self.loss(out, label)
                 d_out, d_lbl = self.inv(out), self.inv(label)  # de-normalize
                 d_loss = self.loss(d_out, d_lbl)
@@ -48,6 +49,7 @@ class Trainer:
                 loss.backward()
                 if self.args["grad_norm"]: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
                 self.optimizer.step()
+                bar.set_postfix({"loss": f"{d_loss.item():.4f}"})

         y_pred, y_true = torch.cat(y_pred), torch.cat(y_true)
         mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
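
A minimal, self-contained sketch of the progress-bar pattern this patch adds to both trainers: wrap the loader in tqdm, then call set_postfix() on each step so the current loss is shown next to the bar. The model, loss, optimizer, and loader below are illustrative stand-ins, not code from this repository.

import torch
from tqdm import tqdm

# Illustrative stand-ins; the real trainers get these from the config and data pipeline.
model = torch.nn.Linear(8, 1)
loss_fn = torch.nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loader = [(torch.randn(16, 8), torch.randn(16, 1)) for _ in range(10)]

bar = tqdm(loader, desc="train 1", total=len(loader))
for data, target in bar:
    out = model(data)
    loss = loss_fn(out, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    bar.set_postfix({"loss": f"{loss.item():.4f}"})  # live loss displayed in the bar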