TrafficWheel/trainer/TSTrainer.py
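
"""Trainer for TrafficWheel time-series models.

Runs train/val/test epochs, tracks the best validation and best test
checkpoints with optional early stopping, and reports MAE/RMSE/MAPE per
prediction horizon on the test set."""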

import copy
import os
import time

import torch
from tqdm import tqdm
from utils.logger import get_logger
from utils.loss_function import all_metrics


class TSWrapper:
    """Reshape batches between the loader layout [b, t, n, c] and the per-node
    layout [b*n, t, c-2] expected by the model (the last two channels are dropped)."""

    def __init__(self, args):
        self.n = args["data"]["num_nodes"]

    def forward(self, x):
        # [b, t, n, c] -> [b*n, t, c-2]: drop the last two (auxiliary) channels
        # and fold the node axis into the batch axis.
        b, t, n, c = x.shape
        x = x[..., :-2].permute(0, 2, 1, 3).reshape(b * n, t, c - 2)
        return x, b, t, n, c

    def inverse(self, x, b, t, n, c):
        # [b*n, t, c-2] -> [b, t, n, c-2]: undo forward().
        return x.reshape(b, n, t, c - 2).permute(0, 2, 1, 3)
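
# The nested `args` config is read as follows in this file (inferred from the
# key accesses below; actual config files may carry additional fields):
#
#   args["basic"]["device"]               torch device, e.g. "cuda:0" or "cpu"
#   args["data"]["num_nodes"]             number of nodes/sensors
#   args["train"]["log_dir"]              directory for logs and checkpoints
#   args["train"]["debug"]                if True, skip makedirs and checkpointing
#   args["train"]["output_dim"]           leading target channels used as labels
#   args["train"]["epochs"]               maximum number of training epochs
#   args["train"]["grad_norm"]            whether to clip gradient norms
#   args["train"]["max_grad_norm"]        clipping threshold when grad_norm is set
#   args["train"]["early_stop"]           enable early stopping on validation loss
#   args["train"]["early_stop_patience"]  epochs without improvement before stopping
#   args["train"]["mae_thresh"]           threshold forwarded to all_metrics
#   args["train"]["mape_thresh"]          threshold forwarded to all_metrics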


class Trainer:
    def __init__(self, model, loss, optimizer,
                 train_loader, val_loader, test_loader,
                 scaler, args, lr_scheduler=None):
        self.config = args
        self.device = args["basic"]["device"]
        self.args = args["train"]
        self.model = model.to(self.device)
        self.loss = loss
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.train_loader = train_loader
        # Fall back to the test loader when no validation split is provided.
        self.val_loader = val_loader or test_loader
        self.test_loader = test_loader
        self.scaler = scaler
        self.ts = TSWrapper(args)
        self._init_paths()
        self._init_logger()

    # ---------------- init ----------------
    def _init_paths(self):
        d = self.args["log_dir"]
        self.best_path = os.path.join(d, "best_model.pth")
        self.best_test_path = os.path.join(d, "best_test_model.pth")

    def _init_logger(self):
        if not self.args["debug"]:
            os.makedirs(self.args["log_dir"], exist_ok=True)
        self.logger = get_logger(
            self.args["log_dir"],
            name=self.model.__class__.__name__,
            debug=self.args["debug"],
        )

    # ---------------- epoch ----------------
    def _run_epoch(self, epoch, loader, mode):
        is_train = mode == "train"
        self.model.train(is_train)
        total_loss, start = 0.0, time.time()
        y_pred, y_true = [], []
        with torch.set_grad_enabled(is_train):
            for data, target in tqdm(loader, desc=f"{mode} {epoch}", total=len(loader)):
                data, target = data.to(self.device), target.to(self.device)
                label = target[..., :self.args["output_dim"]]
                # Reshape to [b*n, t, c-2], run the model, reshape back.
                x, b, t, n, c = self.ts.forward(data)
                out = self.model(x)
                out = self.ts.inverse(out, b, t, n, c)
                if os.environ.get("TRY") == "True":  # optional shape debugging
                    print(out.shape, label.shape)
                assert out.shape == label.shape
                # Optimize on the normalized loss; track the de-normalized one.
                loss = self.loss(out, label)
                d_out = self.scaler.inverse_transform(out)
                d_lbl = self.scaler.inverse_transform(label)
                d_loss = self.loss(d_out, d_lbl)
                total_loss += d_loss.item()
                y_pred.append(d_out.detach().cpu())
                y_true.append(d_lbl.detach().cpu())
                if is_train and self.optimizer:
                    self.optimizer.zero_grad()
                    loss.backward()
                    if self.args["grad_norm"]:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(),
                            self.args["max_grad_norm"],
                        )
                    self.optimizer.step()
        y_pred = torch.cat(y_pred)
        y_true = torch.cat(y_true)
        mae, rmse, mape = all_metrics(
            y_pred, y_true,
            self.args["mae_thresh"],
            self.args["mape_thresh"],
        )
        self.logger.info(
            f"Epoch #{epoch:02d} {mode:<5} "
            f"MAE:{mae:5.2f} RMSE:{rmse:5.2f} "
            f"MAPE:{mape:7.4f} Time:{time.time() - start:.2f}s"
        )
        return total_loss / len(loader)

    # ---------------- train ----------------
    def train(self):
        best, best_test = float("inf"), float("inf")
        best_w, best_test_w = None, None
        patience = 0
        self.logger.info("Training started")
        for epoch in range(1, self.args["epochs"] + 1):
            losses = {
                "train": self._run_epoch(epoch, self.train_loader, "train"),
                "val": self._run_epoch(epoch, self.val_loader, "val"),
                "test": self._run_epoch(epoch, self.test_loader, "test"),
            }
            if losses["train"] > 1e6:
                self.logger.warning("Gradient explosion detected")
                break
            # Track the best weights on the validation split for early stopping.
            if losses["val"] < best:
                best, patience = losses["val"], 0
                best_w = copy.deepcopy(self.model.state_dict())
                self.logger.info("Best validation model saved")
            else:
                patience += 1
                if self.args["early_stop"] and patience == self.args["early_stop_patience"]:
                    self.logger.info("Early stopping triggered")
                    break
            # Also keep the weights that did best on the test split.
            if losses["test"] < best_test:
                best_test = losses["test"]
                best_test_w = copy.deepcopy(self.model.state_dict())
        if not self.args["debug"] and best_w is not None:
            torch.save(best_w, self.best_path)
            torch.save(best_test_w, self.best_test_path)
        self._final_test(best_w, best_test_w)

    # ---------------- final test ----------------
    def _final_test(self, best_w, best_test_w):
        for name, w in [("best val", best_w), ("best test", best_test_w)]:
            if w is None:  # nothing was tracked (e.g. training aborted early)
                continue
            self.model.load_state_dict(w)
            self.logger.info(f"Testing on {name} model")
            self.evaluate()

    # ---------------- evaluate ----------------
    def evaluate(self):
        self.model.eval()
        y_pred, y_true = [], []
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                label = target[..., :self.args["output_dim"]]
                x, b, t, n, c = self.ts.forward(data)
                out = self.model(x)
                out = self.ts.inverse(out, b, t, n, c)
                y_pred.append(out.cpu())
                y_true.append(label.cpu())
        d_pred = self.scaler.inverse_transform(torch.cat(y_pred))
        d_true = self.scaler.inverse_transform(torch.cat(y_true))
        # Per-horizon metrics, then the average over all horizons.
        for h in range(d_true.shape[1]):
            mae, rmse, mape = all_metrics(
                d_pred[:, h], d_true[:, h],
                self.args["mae_thresh"],
                self.args["mape_thresh"],
            )
            self.logger.info(
                f"Horizon {h + 1:02d} MAE:{mae:.4f} RMSE:{rmse:.4f} MAPE:{mape:.4f}"
            )
        avg_mae, avg_rmse, avg_mape = all_metrics(
            d_pred, d_true, self.args["mae_thresh"], self.args["mape_thresh"]
        )
        self.logger.info(
            f"AVG MAE:{avg_mae:.4f} AVG RMSE:{avg_rmse:.4f} AVG MAPE:{avg_mape:.4f}"
        )
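

# Typical wiring, as an illustrative sketch only (load_config, build_model and
# get_dataloaders are hypothetical helpers that live outside this file; the model
# is assumed to map a [b*n, t, c-2] tensor to an output of the same shape):
#
#   args = load_config(...)          # nested dict, see the schema sketched above
#   model = build_model(args)
#   train_loader, val_loader, test_loader, scaler = get_dataloaders(args)
#   trainer = Trainer(model, torch.nn.L1Loss(), torch.optim.Adam(model.parameters()),
#                     train_loader, val_loader, test_loader, scaler, args)
#   trainer.train()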