import copy
import math
import os
import time

import torch
from tqdm import tqdm

from utils.logger import get_logger
from utils.loss_function import all_metrics


class Trainer:
    """Train / validate / test loop with best-checkpoint tracking.

    Every epoch runs over the train, validation and test loaders. The weights
    with the lowest validation loss and, separately, the lowest test loss are
    kept. Both are saved to ``args["train"]["log_dir"]`` (unless ``debug``) and
    then evaluated per horizon on the test loader.
    """

    def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
                 lr_scheduler=None):
        """
        Args:
            model: torch module; moved to ``args["basic"]["device"]``.
            loss: callable ``loss(pred, label)`` returning a scalar tensor.
            optimizer: torch optimizer; when ``None`` no parameter updates happen.
            train_loader, val_loader, test_loader: iterables of ``(data, target)``
                batches with ``len()``. ``val_loader`` falls back to ``test_loader``
                when it is ``None`` or empty.
            scaler: sequence of scalers with ``inverse_transform``; scaler ``i``
                de-normalizes output channel ``i``.
            args: config dict; ``args["basic"]["device"]`` and ``args["train"]``
                are read.
            lr_scheduler: stored only. NOTE(review): it is never stepped inside
                this class — confirm this is intended.
        """
        self.device, self.args = args["basic"]["device"], args["train"]
        self.model, self.loss, self.optimizer, self.lr_scheduler = model.to(self.device), loss, optimizer, lr_scheduler
        # `or` relies on the loader's len(): None or an empty val loader -> use test loader.
        self.train_loader, self.val_loader, self.test_loader = train_loader, val_loader or test_loader, test_loader
        self.scaler = scaler
        self._init_paths()
        self._init_logger()

    # ---------- shape magic (replaces TSWrapper) ----------
    def pack(self, x):
        """Fold the node axis into the batch axis.

        Assumes ``x`` is ``(B, T, N, C)`` (the permute/reshape fix rank 4). The
        last two channels are dropped — presumably auxiliary time features;
        TODO confirm. Returns ``(packed, x.shape)`` with ``packed`` of shape
        ``(B * N, T, C - 2)``.
        """
        packed = x[..., :-2].permute(0, 2, 1, 3).reshape(-1, x.size(1), x.size(3) - 2)
        return packed, x.shape

    def unpack(self, y, s):
        """Inverse of :meth:`pack`: reshape model output back to ``(B, T, N, D)``.

        ``s`` is the original input shape returned by :meth:`pack`.
        NOTE(review): the input length ``s[1]`` is reused as the time axis, so
        this is only correct when the model's output length equals its input
        length.
        """
        return y.reshape(s[0], s[2], s[1], -1).permute(0, 2, 1, 3)

    def inv(self, x):
        """De-normalize ``x`` channel by channel with ``self.scaler[i]``, then concatenate on the last axis."""
        return torch.cat([s.inverse_transform(x[..., i:i + 1]) for i, s in enumerate(self.scaler)], dim=-1)

    # ---------------- init ----------------
    def _init_paths(self):
        """Set the checkpoint paths for the best-val and best-test weights inside ``log_dir``."""
        d = self.args["log_dir"]
        self.best_path = os.path.join(d, "best_model.pth")
        self.best_test_path = os.path.join(d, "best_test_model.pth")

    def _init_logger(self):
        """Create ``log_dir`` (outside debug mode) and the logger named after the model class."""
        if not self.args["debug"]:
            os.makedirs(self.args["log_dir"], exist_ok=True)
        self.logger = get_logger(self.args["log_dir"], name=self.model.__class__.__name__, debug=self.args["debug"])

    # ---------------- epoch ----------------
    def _run_epoch(self, epoch, loader, mode):
        """Run one pass over ``loader`` and log de-normalized metrics.

        Args:
            epoch: epoch index, used in the progress bar and log line.
            loader: iterable of ``(data, target)`` batches.
            mode: ``"train"`` enables gradients and optimizer steps; any other
                value runs the model in eval mode without gradients.

        Returns:
            Mean per-batch loss computed on de-normalized predictions and labels.

        Raises:
            ValueError: if ``loader`` yields no batches.
            AssertionError: in ``TRY=True`` dry-run mode, after the first batch.
        """
        is_train = mode == "train"
        self.model.train(is_train)  # train(False) is exactly eval()
        try_mode = os.environ.get("TRY") == "True"
        output_dim = self.args["output_dim"]
        total_loss, n_batches, start = 0.0, 0, time.time()
        y_pred, y_true = [], []

        with torch.set_grad_enabled(is_train):
            bar = tqdm(loader, desc=f"{mode} {epoch}", total=len(loader))
            for data, target in bar:
                data, target = data.to(self.device), target.to(self.device)
                label = target[..., :output_dim]
                x, shp = self.pack(data)
                out = self.unpack(self.model(x), shp)

                if try_mode:
                    # Dry-run hook: report whether output/label shapes agree, then stop.
                    # Explicit raise (not `assert`) so it still fires under `python -O`.
                    print(f"{'[✅]' if out.shape == label.shape else '❌'} "
                          f"out: {out.shape}, label: {label.shape} \n")
                    raise AssertionError("TRY mode: stopping after the first batch")

                loss = self.loss(out, label)  # optimized on the normalized scale

                # The de-normalized loss/predictions are only reported, never
                # back-propagated, so keep them out of the autograd graph.
                with torch.no_grad():
                    d_out, d_lbl = self.inv(out), self.inv(label)
                    d_loss = self.loss(d_out, d_lbl)
                total_loss += d_loss.item()
                n_batches += 1
                y_pred.append(d_out.cpu())
                y_true.append(d_lbl.cpu())

                if is_train and self.optimizer is not None:
                    self.optimizer.zero_grad()
                    loss.backward()
                    if self.args["grad_norm"]:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
                    self.optimizer.step()

                bar.set_postfix({"loss": f"{d_loss.item():.4f}"})

        if n_batches == 0:
            raise ValueError(f"{mode} loader yielded no batches")

        y_pred, y_true = torch.cat(y_pred), torch.cat(y_true)
        mae, rmse, mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])
        self.logger.info(f"Epoch #{epoch:02d} {mode:<5} MAE:{mae:5.2f} RMSE:{rmse:5.2f} MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s")
        return total_loss / n_batches

    # ---------------- train ----------------
    def train(self):
        """Train for ``args["epochs"]`` epochs, then save and test the best weights.

        Stops early on a non-finite or exploding (> 1e6) train loss. When
        ``early_stop`` is set, it also stops after ``early_stop_patience``
        epochs without validation improvement.
        """
        best, best_test = float("inf"), float("inf")
        best_w, best_test_w = None, None
        patience = 0
        self.logger.info("Training started")

        for epoch in range(1, self.args["epochs"] + 1):
            losses = {
                "train": self._run_epoch(epoch, self.train_loader, "train"),
                "val": self._run_epoch(epoch, self.val_loader, "val"),
                "test": self._run_epoch(epoch, self.test_loader, "test"),
            }
            # NaN never compares > 1e6, so check finiteness explicitly.
            if not math.isfinite(losses["train"]) or losses["train"] > 1e6:
                self.logger.warning("Gradient explosion detected")
                break

            # Record the best-test weights before any early-stop break, so the
            # final epoch is not skipped.
            if losses["test"] < best_test:
                best_test, best_test_w = losses["test"], copy.deepcopy(self.model.state_dict())
                self.logger.info(f"Best test model saved at Epoch {epoch:02d}#")

            if losses["val"] < best:
                best, patience, best_w = losses["val"], 0, copy.deepcopy(self.model.state_dict())
                self.logger.info(f"Best model updated at Epoch {epoch:02d}#")
            else:
                patience += 1
                if self.args["early_stop"] and patience >= self.args["early_stop_patience"]:
                    break

        if not self.args["debug"]:
            for w, path in ((best_w, self.best_path), (best_test_w, self.best_test_path)):
                if w is not None:  # never write a `None` checkpoint
                    torch.save(w, path)
        self._final_test(best_w, best_test_w)

    # ---------------- final test ----------------
    def _final_test(self, best_w, best_test_w):
        """Load each recorded state dict and run :meth:`evaluate`; skip any that are ``None``."""
        for name, w in (("best val", best_w), ("best test", best_test_w)):
            if w is None:
                self.logger.warning(f"No {name} model was recorded; skipping its test")
                continue
            self.model.load_state_dict(w)
            self.logger.info(f"Testing on {name} model")
            self.evaluate()

    # ---------------- evaluate ----------------
    def evaluate(self):
        """Log per-horizon and average MAE/RMSE/MAPE on the test loader.

        Metrics are computed on de-normalized values. Axis 1 of the
        concatenated result is treated as the horizon axis.

        Raises:
            ValueError: if the test loader yields no batches.
        """
        self.model.eval()
        output_dim = self.args["output_dim"]
        y_pred, y_true = [], []
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                label = target[..., :output_dim]
                x, shp = self.pack(data)
                out = self.unpack(self.model(x), shp)
                y_pred.append(out.cpu())
                y_true.append(label.cpu())

        if not y_pred:
            raise ValueError("test loader yielded no batches")

        # De-normalize. NOTE(review): unlike _run_epoch, this runs on CPU tensors
        # — confirm the scalers' inverse_transform is device-agnostic.
        d_pred, d_true = self.inv(torch.cat(y_pred)), self.inv(torch.cat(y_true))
        mae_thresh, mape_thresh = self.args["mae_thresh"], self.args["mape_thresh"]
        for t in range(d_true.shape[1]):
            mae, rmse, mape = all_metrics(d_pred[:, t], d_true[:, t], mae_thresh, mape_thresh)
            self.logger.info(f"Horizon {t+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}")
        avg_mae, avg_rmse, avg_mape = all_metrics(d_pred, d_true, mae_thresh, mape_thresh)
        self.logger.info(f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}")