import copy
import os
import time

import torch
from tqdm import tqdm

from utils.logger import get_logger
from utils.loss_function import all_metrics


class Trainer:
    """Wraps a model with its training, validation, and test loops."""

    def __init__(self, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler=None):
        self.config, self.device, self.args = args, args["basic"]["device"], args["train"]
        self.model, self.loss, self.optimizer, self.lr_scheduler = model, loss, optimizer, lr_scheduler
        self.train_loader, self.val_loader, self.test_loader, self.scaler = train_loader, val_loader, test_loader, scaler

        # Checkpoint paths for the best-on-validation and best-on-test weights.
        log_dir = self.args["log_dir"]
        self.best_path = os.path.join(log_dir, "best_model.pth")
        self.best_test_path = os.path.join(log_dir, "best_test_model.pth")

        if not self.args["debug"]:
            os.makedirs(log_dir, exist_ok=True)
        self.logger = get_logger(log_dir, name=self.model.__class__.__name__, debug=self.args["debug"])
        self.logger.info(f"Experiment log path: {log_dir}")

    def train(self):
        """Run the full training loop with early stopping and checkpointing."""
        best_model = best_test_model = None
        best_loss = best_test_loss = float("inf")
        not_improved_count = 0

        self.logger.info("Training process started")

        for epoch in range(1, self.args["epochs"] + 1):
            train_loss = self._run_epoch(epoch, self.train_loader, "train")
            # Fall back to the test loader when no validation split is provided.
            val_loss = self._run_epoch(epoch, self.val_loader or self.test_loader, "val")
            test_loss = self._run_epoch(epoch, self.test_loader, "test")

            if train_loss > 1e6:
                self.logger.warning("Gradient explosion detected. Ending training.")
                break

            if val_loss < best_loss:
                best_loss, not_improved_count, best_model = val_loss, 0, copy.deepcopy(self.model.state_dict())
                self.logger.info("Best validation model saved!")
            else:
                not_improved_count += 1
                if self.args["early_stop"] and not_improved_count == self.args["early_stop_patience"]:
                    self.logger.info(f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
                    break

            if test_loss < best_test_loss:
                best_test_loss, best_test_model = test_loss, copy.deepcopy(self.model.state_dict())

        torch.save(best_model, self.best_path)
        torch.save(best_test_model, self.best_test_path)
        self.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")

        # Evaluate both checkpoints on the test set with per-horizon metrics.
        for model_name, state_dict in [("best validation", best_model), ("best test", best_test_model)]:
            self.model.load_state_dict(state_dict)
            self.logger.info(f"Testing on {model_name} model")
            self._run_epoch(None, self.test_loader, "test", log_horizon=True)

    def _run_epoch(self, epoch, dataloader, mode, log_horizon=False):
        """Run a single pass over `dataloader`; gradients flow only in train mode."""
        optimizer_step = mode == "train"
        if optimizer_step:
            self.model.train()
        else:
            self.model.eval()

        total_loss, epoch_time = 0, time.time()
        y_pred, y_true = [], []

        with torch.set_grad_enabled(optimizer_step):
            for data, target in tqdm(dataloader, total=len(dataloader), desc=f"{mode.capitalize()} Epoch {epoch}" if epoch else mode):
                data, target = data.to(self.device), target.to(self.device)
                label = target[..., :self.args["output_dim"]]

                output = self.model(data)
                # Optimize on the normalized loss; report the denormalized one.
                loss = self.loss(output, label)
                d_output, d_label = self.scaler.inverse_transform(output), self.scaler.inverse_transform(label)
                d_loss = self.loss(d_output, d_label)

                total_loss += d_loss.item()
                y_pred.append(d_output.detach().cpu())
                y_true.append(d_label.detach().cpu())

                if optimizer_step and self.optimizer:
                    self.optimizer.zero_grad()
                    loss.backward()
                    if self.args["grad_norm"]:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args["max_grad_norm"])
                    self.optimizer.step()

        y_pred, y_true = torch.cat(y_pred, dim=0), torch.cat(y_true, dim=0)

        # Optionally report metrics for every prediction horizon.
        if log_horizon:
            for t in range(y_true.shape[1]):
                mae, rmse, mape = all_metrics(y_pred[:, t, ...], y_true[:, t, ...], self.args["mae_thresh"], self.args["mape_thresh"])
                self.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")

        avg_mae, avg_rmse, avg_mape = all_metrics(y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"])

        # `epoch` is None for the final evaluation runs.
        if epoch:
            self.logger.info(f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{avg_mae:5.2f} | RMSE:{avg_rmse:5.2f} | MAPE:{avg_mape:7.4f} | Time: {time.time()-epoch_time:.2f} s")
        else:
            self.logger.info(f"{mode.capitalize():<5} MAE:{avg_mae:.4f} | RMSE:{avg_rmse:.4f} | MAPE:{avg_mape:.4f}")

        return total_loss / len(dataloader)

    def test(self, path=None):
        """Evaluate on the test set, optionally loading weights from `path`."""
        if path:
            # Accept both raw state dicts (as saved by train()) and checkpoints
            # wrapped as {"state_dict": ...}.
            checkpoint = torch.load(path, map_location=self.device)
            self.model.load_state_dict(checkpoint.get("state_dict", checkpoint))
            self.model.to(self.device)

        self._run_epoch(None, self.test_loader, "test", log_horizon=True)
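

# ---------------------------------------------------------------------------
# Illustrative smoke-test sketch, not part of the original pipeline: it wires
# the Trainer to a toy linear model, synthetic tensors, and an identity
# scaler. The config keys mirror the lookups in this file, but every value
# (epochs, thresholds, log_dir, tensor shapes) is a placeholder assumption.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    class _IdentityScaler:
        def inverse_transform(self, x):
            return x

    # Synthetic data shaped (samples, horizon, nodes, features).
    x = torch.randn(32, 12, 4, 1)
    y = torch.randn(32, 12, 4, 1)
    loader = DataLoader(TensorDataset(x, y), batch_size=8)

    # A toy model mapping the input window to an output of the same shape.
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(12 * 4, 12 * 4),
        nn.Unflatten(1, (12, 4, 1)),
    )

    args = {
        "basic": {"device": "cpu"},
        "train": {
            "log_dir": "./logs", "debug": False, "epochs": 2,
            "early_stop": True, "early_stop_patience": 5,
            "output_dim": 1, "grad_norm": True, "max_grad_norm": 5,
            "mae_thresh": None, "mape_thresh": 0.001,
        },
    }

    trainer = Trainer(
        model, nn.L1Loss(), torch.optim.Adam(model.parameters()),
        loader, loader, loader, _IdentityScaler(), args,
    )
    trainer.train()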