import torch import math import os import time import copy import numpy as np from lib.logger import get_logger from lib.metrics import All_Metrics from lib.TrainInits import print_model_parameters from lib.training_stats import TrainingStats class Trainer(object): def __init__( self, model, vector_field_f, vector_field_g, loss, optimizer, train_loader, val_loader, test_loader, scaler, args, lr_scheduler, device, times, w, ): super(Trainer, self).__init__() self.model = model self.vector_field_f = vector_field_f self.vector_field_g = vector_field_g self.loss = loss self.optimizer = optimizer self.train_loader = train_loader self.val_loader = val_loader self.test_loader = test_loader self.scaler = scaler self.args = args self.lr_scheduler = lr_scheduler self.train_per_epoch = len(train_loader) if val_loader != None: self.val_per_epoch = len(val_loader) self.best_path = os.path.join(self.args.log_dir, "best_model.pth") self.loss_figure_path = os.path.join(self.args.log_dir, "loss.png") # log if os.path.isdir(args.log_dir) == False and not args.debug: os.makedirs(args.log_dir, exist_ok=True) self.logger = get_logger(args.log_dir, name=args.model, debug=args.debug) self.logger.info("Experiment log path in: {}".format(args.log_dir)) total_param = print_model_parameters(model, only_num=False) for arg, value in sorted(vars(args).items()): self.logger.info("Argument %s: %r", arg, value) self.logger.info(self.model) self.logger.info("Total params: {}".format(str(total_param))) self.device = device self.times = times.to(self.device, dtype=torch.float) self.w = w # Stats tracker self.stats = TrainingStats(device=device) def val_epoch(self, epoch, val_dataloader): self.model.eval() total_val_loss = 0 with torch.no_grad(): for batch_idx, batch in enumerate(self.val_loader): start_time = time.time() # for iter, batch in enumerate(val_dataloader): batch = tuple(b.to(self.device, dtype=torch.float) for b in batch) *valid_coeffs, target = batch # data = data[..., :self.args.input_dim] label = target[..., : self.args.output_dim] output = self.model(self.times, valid_coeffs) if self.args.real_value: label = self.scaler.inverse_transform(label) loss = self.loss(output.cuda(), label) # a whole batch of Metr_LA is filtered if not torch.isnan(loss): total_val_loss += loss.item() step_time = time.time() - start_time self.stats.record_step_time(step_time, "val") val_loss = total_val_loss / len(val_dataloader) self.logger.info( "**********Val Epoch {}: average Loss: {:.6f}".format(epoch, val_loss) ) self.stats.record_memory_usage() if self.args.tensorboard: self.w.add_scalar(f"valid/loss", val_loss, epoch) return val_loss def train_epoch(self, epoch): self.model.train() total_loss = 0 # for batch_idx, (data, target) in enumerate(self.train_loader): # for batch_idx, (data, target) in enumerate(self.train_loader): for batch_idx, batch in enumerate(self.train_loader): start_time = time.time() batch = tuple(b.to(self.device, dtype=torch.float) for b in batch) *train_coeffs, target = batch # data = data[..., :self.args.input_dim] label = target[..., : self.args.output_dim] # (..., 1) self.optimizer.zero_grad() # #teacher_forcing for RNN encoder-decoder model # #if teacher_forcing_ratio = 1: use label as input in the decoder for all steps # if self.args.teacher_forcing: # global_step = (epoch - 1) * self.train_per_epoch + batch_idx # teacher_forcing_ratio = self._compute_sampling_threshold(global_step, self.args.tf_decay_steps) # else: # teacher_forcing_ratio = 1. # data and target shape: B, T, N, F; output shape: B, T, N, F output = self.model(self.times, train_coeffs) # output = self.model(train_coeffs, target, teacher_forcing_ratio=teacher_forcing_ratio) if self.args.real_value: label = self.scaler.inverse_transform(label) loss = self.loss(output.cuda(), label) # loss = _add_weight_regularisation(loss, self.vector_field_g) #TODO: regularization # loss = _add_weight_regularisation(loss, self.vector_field_f) #TODO: regularization loss.backward() # add max grad clipping if self.args.grad_norm: torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.args.max_grad_norm ) self.optimizer.step() step_time = time.time() - start_time self.stats.record_step_time(step_time, "train") total_loss += loss.item() # log information if batch_idx % self.args.log_step == 0: self.logger.info( "Train Epoch {}: {}/{} Loss: {:.6f}".format( epoch, batch_idx, self.train_per_epoch, loss.item() ) ) train_epoch_loss = total_loss / self.train_per_epoch self.logger.info( "**********Train Epoch {}: averaged Loss: {:.6f}".format( epoch, train_epoch_loss ) ) self.stats.record_memory_usage() if self.args.tensorboard: self.w.add_scalar(f"train/loss", train_epoch_loss, epoch) # learning rate decay if self.args.lr_decay: self.lr_scheduler.step() return train_epoch_loss def train(self): best_model = None best_loss = float("inf") not_improved_count = 0 train_loss_list = [] val_loss_list = [] self.stats.start_training() start_time = time.time() for epoch in range(1, self.args.epochs + 1): # epoch_time = time.time() train_epoch_loss = self.train_epoch(epoch) # print(time.time()-epoch_time) # exit() if self.val_loader == None: val_dataloader = self.test_loader else: val_dataloader = self.val_loader val_epoch_loss = self.val_epoch(epoch, val_dataloader) # print('LR:', self.optimizer.param_groups[0]['lr']) train_loss_list.append(train_epoch_loss) val_loss_list.append(val_epoch_loss) if train_epoch_loss > 1e6: self.logger.warning("Gradient explosion detected. Ending...") break # if self.val_loader == None: # val_epoch_loss = train_epoch_loss if val_epoch_loss < best_loss: best_loss = val_epoch_loss not_improved_count = 0 best_state = True else: not_improved_count += 1 best_state = False # early stop if self.args.early_stop: if not_improved_count == self.args.early_stop_patience: self.logger.info( "Validation performance didn't improve for {} epochs. " "Training stops.".format(self.args.early_stop_patience) ) break # save the best state if best_state == True: self.logger.info( "*********************************Current best model saved!" ) best_model = copy.deepcopy(self.model.state_dict()) # if epoch%10==0:#test # self.model.load_state_dict(best_model) # #self.val_epoch(self.args.epochs, self.test_loader) # self.test_simple(self.model, self.args, self.test_loader, self.scaler, self.logger, None, self.times) training_time = time.time() - start_time self.logger.info( "Total training time: {:.4f}min, best loss: {:.6f}".format( (training_time / 60), best_loss ) ) self.stats.end_training() self.stats.report(self.logger) try: total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad ) self.logger.info(f"Trainable params: {total_params}") except Exception: pass # save the best model to file if not self.args.debug: torch.save(best_model, self.best_path) self.logger.info("Saving current best model to " + self.best_path) # if epoch==10:#test # self.model.load_state_dict(best_model) # #self.val_epoch(self.args.epochs, self.test_loader) # self.test(self.model, self.args, self.test_loader, self.scaler, self.logger, None, self.times) self.model.load_state_dict(best_model) # self.val_epoch(self.args.epochs, self.test_loader) self.test( self.model, self.args, self.test_loader, self.scaler, self.logger, None, self.times, ) def save_checkpoint(self): state = { "state_dict": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "config": self.args, } torch.save(state, self.best_path) self.logger.info("Saving current best model to " + self.best_path) @staticmethod def test(model, args, data_loader, scaler, logger, path, times): if path != None: check_point = torch.load(path) state_dict = check_point["state_dict"] args = check_point["config"] model.load_state_dict(state_dict) model.to(args.device) model.eval() y_pred = [] y_true = [] with torch.no_grad(): for batch_idx, batch in enumerate(data_loader): batch = tuple(b.to(args.device, dtype=torch.float) for b in batch) *test_coeffs, target = batch label = target[..., : args.output_dim] output = model(times.to(args.device, dtype=torch.float), test_coeffs) y_true.append(label) y_pred.append(output) y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) if args.real_value: y_pred = torch.cat(y_pred, dim=0) else: y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) np.save( args.log_dir + "/{}_true.npy".format(args.dataset), y_true.cpu().numpy() ) np.save( args.log_dir + "/{}_pred.npy".format(args.dataset), y_pred.cpu().numpy() ) for t in range(y_true.shape[1]): mae, rmse, mape, _, _ = All_Metrics( y_pred[:, t, ...], y_true[:, t, ...], args.mae_thresh, args.mape_thresh ) logger.info( "Horizon {:02d}, MAE: {:.2f}, RMSE: {:.2f}, MAPE: {:.4f}%".format( t + 1, mae, rmse, mape * 100 ) ) mae, rmse, mape, _, _ = All_Metrics( y_pred, y_true, args.mae_thresh, args.mape_thresh ) logger.info( "Average Horizon, MAE: {:.2f}, RMSE: {:.2f}, MAPE: {:.4f}%".format( mae, rmse, mape * 100 ) ) @staticmethod def test_simple(model, args, data_loader, scaler, logger, path, times): if path != None: check_point = torch.load(path) state_dict = check_point["state_dict"] args = check_point["config"] model.load_state_dict(state_dict) model.to(args.device) model.eval() y_pred = [] y_true = [] with torch.no_grad(): for batch_idx, batch in enumerate(data_loader): # for batch_idx, (data, target) in enumerate(data_loader): batch = tuple(b.to(args.device, dtype=torch.float) for b in batch) *test_coeffs, target = batch # data = data[..., :args.input_dim] label = target[..., : args.output_dim] output = model(times.to(args.device, dtype=torch.float), test_coeffs) y_true.append(label) y_pred.append(output) y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) if args.real_value: y_pred = torch.cat(y_pred, dim=0) else: y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) for t in range(y_true.shape[1]): mae, rmse, mape, _, _ = All_Metrics( y_pred[:, t, ...], y_true[:, t, ...], args.mae_thresh, args.mape_thresh ) mae, rmse, mape, _, _ = All_Metrics( y_pred, y_true, args.mae_thresh, args.mape_thresh ) logger.info( "Average Horizon, MAE: {:.2f}, RMSE: {:.2f}, MAPE: {:.4f}%".format( mae, rmse, mape * 100 ) ) @staticmethod def _compute_sampling_threshold(global_step, k): """ Computes the sampling probability for scheduled sampling using inverse sigmoid. :param global_step: :param k: :return: """ return k / (k + math.exp(global_step / k)) def _add_weight_regularisation(total_loss, regularise_parameters, scaling=0.03): for parameter in regularise_parameters.parameters(): if parameter.requires_grad: total_loss = total_loss + scaling * parameter.norm() return total_loss