import copy
import os
import time

import torch
from tqdm import tqdm

from utils.logger import get_logger
from utils.loss_function import all_metrics


class TSWrapper:
    """Reshape 4-D batches into per-node sequences and back.

    ``forward`` drops the last two channels of the input and folds the node
    axis into the batch axis; ``inverse`` undoes the folding on the model
    output.
    """

    def __init__(self, args):
        # Stored for reference only; the reshapes below read the node count
        # from the tensor itself.
        self.n = args['data']['num_nodes']

    def forward(self, x):
        """Turn ``[b, t, n, c]`` into ``[b*n, t, c-2]``.

        Returns the reshaped tensor together with the original ``b, t, n, c``
        so that ``inverse`` can restore the layout.
        """
        b, t, n, c = x.shape
        # The last two channels are discarded before the model sees the data.
        # NOTE(review): presumably auxiliary (e.g. time) features — confirm.
        x = x[..., :-2].permute(0, 2, 1, 3).reshape(b * n, t, c - 2)
        return x, b, t, n, c

    def inverse(self, x, b, t, n, c):
        """Turn ``[b*n, t, c-2]`` back into ``[b, t, n, c-2]``.

        ``c`` is the channel count of the *original* input, i.e. the value
        returned by ``forward``.
        """
        return x.reshape(b, n, t, c - 2).permute(0, 2, 1, 3)


class Trainer:
    """Train / validate / test loop around a model wrapped by ``TSWrapper``."""

    def __init__(self, model, loss, optimizer,
                 train_loader, val_loader, test_loader,
                 scaler, args, lr_scheduler=None):
        """Store collaborators and set up log paths and the logger.

        ``args`` is the full config mapping; ``args["train"]`` is kept as
        ``self.args`` and ``args["basic"]["device"]`` selects the device.
        """
        self.config = args
        self.device = args["basic"]["device"]
        self.args = args["train"]

        self.model = model.to(self.device)
        self.loss = loss
        self.optimizer = optimizer
        # NOTE(review): the scheduler is stored but nothing in this class
        # calls its step(); confirm it is stepped elsewhere.
        self.lr_scheduler = lr_scheduler

        self.train_loader = train_loader
        # Fall back to the test loader when no (truthy) validation loader
        # is supplied.
        self.val_loader = val_loader or test_loader
        self.test_loader = test_loader

        self.scaler = scaler
        self.ts = TSWrapper(args)

        self._init_paths()
        self._init_logger()

    # ---------------- init ----------------
    def _init_paths(self):
        """Derive checkpoint file paths from the configured log directory."""
        d = self.args["log_dir"]
        self.best_path = os.path.join(d, "best_model.pth")
        self.best_test_path = os.path.join(d, "best_test_model.pth")

    def _init_logger(self):
        """Create the log directory (unless in debug mode) and the logger."""
        if not self.args["debug"]:
            os.makedirs(self.args["log_dir"], exist_ok=True)
        self.logger = get_logger(
            self.args["log_dir"],
            name=self.model.__class__.__name__,
            debug=self.args["debug"],
        )

    # ---------------- forward ----------------
    def _forward_batch(self, data, target):
        """Move one batch to the device and return ``(prediction, label)``.

        The label is the first ``output_dim`` channels of ``target``; the
        prediction is the model output restored to the 4-D layout.
        """
        data, target = data.to(self.device), target.to(self.device)
        label = target[..., :self.args["output_dim"]]

        x, b, t, n, c = self.ts.forward(data)
        out = self.ts.inverse(self.model(x), b, t, n, c)
        return out, label

    # ---------------- epoch ----------------
    def _run_epoch(self, epoch, loader, mode):
        """Run one pass over ``loader`` and return the mean per-batch loss.

        The returned loss is computed on inverse-transformed (de-normalised)
        values; the loss used for back-propagation is computed on the raw
        model output. Gradients are only enabled when ``mode == "train"``.
        """
        is_train = mode == "train"
        if is_train:
            self.model.train()
        else:
            self.model.eval()

        total_loss, start = 0.0, time.time()
        y_pred, y_true = [], []

        with torch.set_grad_enabled(is_train):
            for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)):
                out, label = self._forward_batch(data, target)

                # Debug aid: enabled by setting the TRY environment variable.
                if os.environ.get("TRY") == "True":
                    print(out.shape, label.shape)
                    assert out.shape == label.shape

                loss = self.loss(out, label)

                d_out = self.scaler.inverse_transform(out)
                d_lbl = self.scaler.inverse_transform(label)
                d_loss = self.loss(d_out, d_lbl)

                total_loss += d_loss.item()
                y_pred.append(d_out.detach().cpu())
                y_true.append(d_lbl.detach().cpu())

                if is_train and self.optimizer:
                    self.optimizer.zero_grad()
                    loss.backward()
                    if self.args["grad_norm"]:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(),
                            self.args["max_grad_norm"]
                        )
                    self.optimizer.step()

        y_pred = torch.cat(y_pred)
        y_true = torch.cat(y_true)

        mae, rmse, mape = all_metrics(
            y_pred, y_true,
            self.args["mae_thresh"],
            self.args["mape_thresh"]
        )

        self.logger.info(
            f"Epoch #{epoch:02d} {mode:<5} "
            f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} "
            f"MAPE:{mape:7.4f} Time:{time.time()-start:.2f}s"
        )
        return total_loss / len(loader)

    # ---------------- train ----------------
    def train(self):
        """Run the full training loop, save the best weights and test them.

        Each epoch runs train, val and test passes. Training stops early on
        a train loss above 1e6 or, when enabled, after
        ``early_stop_patience`` epochs without validation improvement.
        """
        best, best_test = float("inf"), float("inf")
        best_w, best_test_w = None, None
        patience = 0

        self.logger.info("Training started")

        for epoch in range(1, self.args["epochs"] + 1):
            losses = {
                "train": self._run_epoch(epoch, self.train_loader, "train"),
                "val": self._run_epoch(epoch, self.val_loader, "val"),
                "test": self._run_epoch(epoch, self.test_loader, "test"),
            }

            if losses["train"] > 1e6:
                self.logger.warning("Gradient explosion detected")
                break

            if losses["val"] < best:
                best, patience = losses["val"], 0
                best_w = copy.deepcopy(self.model.state_dict())
                self.logger.info("Best validation model saved")
            else:
                patience += 1

            if self.args["early_stop"] and patience == self.args["early_stop_patience"]:
                self.logger.info("Early stopping triggered")
                break

            if losses["test"] < best_test:
                best_test = losses["test"]
                best_test_w = copy.deepcopy(self.model.state_dict())

        if not self.args["debug"]:
            # A snapshot is still None when the loop ended before any
            # improvement was recorded (e.g. explosion in the first epoch or
            # a NaN loss); writing None to disk would only produce a useless
            # checkpoint.
            for weights, path in (
                (best_w, self.best_path),
                (best_test_w, self.best_test_path),
            ):
                if weights is not None:
                    torch.save(weights, path)

        self._final_test(best_w, best_test_w)

    # ---------------- final test ----------------
    def _final_test(self, best_w, best_test_w):
        """Evaluate the best-validation and best-test snapshots in turn.

        A snapshot that was never recorded (``None``) is skipped with a
        warning instead of crashing in ``load_state_dict``.
        """
        for name, w in [("best val", best_w), ("best test", best_test_w)]:
            if w is None:
                self.logger.warning(f"No {name} model was recorded; skipping its test")
                continue
            self.model.load_state_dict(w)
            self.logger.info(f"Testing on {name} model")
            self.evaluate()

    # ---------------- evaluate ----------------
    def evaluate(self):
        """Log per-horizon and average metrics on the test loader.

        Predictions and labels are inverse-transformed before metrics are
        computed; axis 1 of the stacked tensors is iterated as the horizon.
        """
        self.model.eval()
        y_pred, y_true = [], []

        with torch.no_grad():
            for data, target in self.test_loader:
                out, label = self._forward_batch(data, target)
                y_pred.append(out.cpu())
                y_true.append(label.cpu())

        d_pred = self.scaler.inverse_transform(torch.cat(y_pred))
        d_true = self.scaler.inverse_transform(torch.cat(y_true))

        for step in range(d_true.shape[1]):
            mae, rmse, mape = all_metrics(
                d_pred[:, step], d_true[:, step],
                self.args["mae_thresh"],
                self.args["mape_thresh"]
            )
            self.logger.info(
                f"Horizon {step+1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}"
            )

        avg_mae, avg_rmse, avg_mape = all_metrics(
            d_pred, d_true,
            self.args["mae_thresh"],
            self.args["mape_thresh"]
        )
        self.logger.info(
            f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}"
        )