import os
import time

import numpy as np
import torch

from lib import utils
from model.pytorch.dcrnn_model import DCRNNModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DCRNNSupervisor:
    def __init__(self, adj_mx, **kwargs):
        self._kwargs = kwargs
        self._data_kwargs = kwargs.get('data')
        self._model_kwargs = kwargs.get('model')
        self._train_kwargs = kwargs.get('train')

        self.max_grad_norm = self._train_kwargs.get('max_grad_norm', 1.)

        # logging
        self._log_dir = self._get_log_dir(kwargs)
        log_level = self._kwargs.get('log_level', 'INFO')
        self._logger = utils.get_logger(self._log_dir, __name__, 'info.log', level=log_level)

        # data set
        self._data = utils.load_dataset(**self._data_kwargs)
        self.standard_scaler = self._data['scaler']

        self.num_nodes = int(self._model_kwargs.get('num_nodes', 1))
        self.input_dim = int(self._model_kwargs.get('input_dim', 1))
        self.seq_len = int(self._model_kwargs.get('seq_len'))  # input length for the encoder
        self.output_dim = int(self._model_kwargs.get('output_dim', 1))
        self.use_curriculum_learning = bool(
            self._model_kwargs.get('use_curriculum_learning', False))
        self.horizon = int(self._model_kwargs.get('horizon', 1))  # prediction length for the decoder

        # set up the model
        dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
        self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
        self._logger.info("Model created")

    @staticmethod
    def _get_log_dir(kwargs):
        log_dir = kwargs['train'].get('log_dir')
        if log_dir is None:
            # Derive a descriptive run id from the key hyperparameters.
            batch_size = kwargs['data'].get('batch_size')
            learning_rate = kwargs['train'].get('base_lr')
            max_diffusion_step = kwargs['model'].get('max_diffusion_step')
            num_rnn_layers = kwargs['model'].get('num_rnn_layers')
            rnn_units = kwargs['model'].get('rnn_units')
            structure = '-'.join(
                ['%d' % rnn_units for _ in range(num_rnn_layers)])
            horizon = kwargs['model'].get('horizon')
            filter_type = kwargs['model'].get('filter_type')
            filter_type_abbr = 'L'
            if filter_type == 'random_walk':
                filter_type_abbr = 'R'
            elif filter_type == 'dual_random_walk':
                filter_type_abbr = 'DR'
            run_id = 'dcrnn_%s_%d_h_%d_%s_lr_%g_bs_%d_%s/' % (
                filter_type_abbr, max_diffusion_step, horizon, structure,
                learning_rate, batch_size, time.strftime('%m%d%H%M%S'))
            base_dir = kwargs.get('base_dir')
            log_dir = os.path.join(base_dir, run_id)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        return log_dir

    def save_model(self, epoch):
        models_dir = os.path.join(self._log_dir, 'models')
        if not os.path.exists(models_dir):
            os.makedirs(models_dir)
        # Save the full config alongside the weights so a checkpoint is self-describing.
        config = dict(self._kwargs)
        config['model_state_dict'] = self.dcrnn_model.state_dict()
        config['epoch'] = epoch
        model_path = os.path.join(models_dir, 'epo%d.tar' % epoch)
        torch.save(config, model_path)
        self._logger.info("Saved model at epoch {}".format(epoch))
        return model_path

    def load_model(self, epoch):
        model_path = os.path.join(self._log_dir, 'models', 'epo%d.tar' % epoch)
        assert os.path.exists(model_path), 'Weights at epoch %d not found' % epoch
        checkpoint = torch.load(model_path, map_location='cpu')
        self.dcrnn_model.load_state_dict(checkpoint['model_state_dict'])
        self._logger.info("Loaded model at epoch {}".format(epoch))

    def train(self, **kwargs):
        kwargs.update(self._train_kwargs)
        return self._train(**kwargs)

    def evaluate(self, dataset='val'):
        """
        Computes the mean L1 loss (MAE) over the given dataset.

        :param dataset: 'val' or 'test'
        :return: mean L1 loss
        """
        with torch.no_grad():
            self.dcrnn_model.eval()
            val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
            losses = []
            criterion = torch.nn.L1Loss()
            for _, (x, y) in enumerate(val_iterator):
                x, y = self._prepare_data(x, y)
                output = self.dcrnn_model(x)
                loss = self._compute_loss(y, output, criterion)
                losses.append(loss.item())
            return np.mean(losses)
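    # A worked learning-rate schedule for _train below, with illustrative (not
    # default) values: base_lr=0.01, steps=[20, 30, 40], lr_decay_ratio=0.1 gives
    #   epochs  0-19: lr = 1e-2
    #   epochs 20-29: lr = 1e-3
    #   epochs 30-39: lr = 1e-4
    #   epochs 40+  : lr = 1e-5
    # Note that `min_learning_rate` is accepted for config compatibility but is
    # not enforced by the MultiStepLR scheduler.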
    def _train(self, base_lr, steps, patience=50, epochs=100, min_learning_rate=2e-6,
               lr_decay_ratio=0.1, log_every=10, save_model=1,
               test_every_n_epochs=10, **kwargs):
        # `steps` lists the epoch milestones at which MultiStepLR decays the
        # learning rate by `lr_decay_ratio`.
        min_val_loss = float('inf')
        wait = 0
        batches_seen = 0
        optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr)
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
                                                            gamma=lr_decay_ratio)

        criterion = torch.nn.L1Loss()  # MAE loss

        self._logger.info('Start training ...')
        self._logger.info("num_batches: {}".format(self._data['train_loader'].num_batch))

        for epoch_num in range(epochs):
            # evaluate() switches the model to eval mode, so re-enable train mode
            # at the start of every epoch.
            self.dcrnn_model.train()

            train_iterator = self._data['train_loader'].get_iterator()
            losses = []
            start_time = time.time()

            for _, (x, y) in enumerate(train_iterator):
                optimizer.zero_grad()
                x, y = self._prepare_data(x, y)
                # `batches_seen` drives the curriculum-learning schedule inside the model.
                output = self.dcrnn_model(x, y, batches_seen)
                loss = self._compute_loss(y, output, criterion)
                self._logger.debug(loss.item())
                losses.append(loss.item())
                batches_seen += 1
                loss.backward()
                # gradient clipping - this modifies the gradients in place
                torch.nn.utils.clip_grad_norm_(self.dcrnn_model.parameters(), self.max_grad_norm)
                optimizer.step()

            self._logger.info("epoch complete")
            lr_scheduler.step()

            self._logger.info("evaluating now!")
            val_loss = self.evaluate(dataset='val')
            end_time = time.time()

            if epoch_num % log_every == 0:
                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}, ' \
                          '{:.1f}s'.format(epoch_num, epochs, batches_seen,
                                           np.mean(losses), val_loss,
                                           lr_scheduler.get_last_lr()[0],
                                           (end_time - start_time))
                self._logger.info(message)

            if epoch_num % test_every_n_epochs == 0:
                test_loss = self.evaluate(dataset='test')
                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \
                          '{:.1f}s'.format(epoch_num, epochs, batches_seen,
                                           np.mean(losses), test_loss,
                                           lr_scheduler.get_last_lr()[0],
                                           (end_time - start_time))
                self._logger.info(message)

            if val_loss < min_val_loss:
                wait = 0
                if save_model:
                    model_file_name = self.save_model(epoch_num)
                    self._logger.info(
                        'Val loss decreased from {:.4f} to {:.4f}, '
                        'saving to {}'.format(min_val_loss, val_loss, model_file_name))
                min_val_loss = val_loss
            else:
                wait += 1
                if wait == patience:
                    self._logger.warning('Early stopping at epoch: %d' % epoch_num)
                    break

    def _prepare_data(self, x, y):
        x, y = self._get_x_y(x, y)
        x, y = self._get_x_y_in_correct_dims(x, y)
        return x.to(device), y.to(device)

    def _get_x_y(self, x, y):
        """
        :param x: shape (batch_size, seq_len, num_sensor, input_dim)
        :param y: shape (batch_size, horizon, num_sensor, input_dim)
        :returns x shape (seq_len, batch_size, num_sensor, input_dim)
                 y shape (horizon, batch_size, num_sensor, input_dim)
        """
        x = torch.from_numpy(x).float()
        y = torch.from_numpy(y).float()
        self._logger.debug("X: {}".format(x.size()))
        self._logger.debug("y: {}".format(y.size()))
        # Move the time dimension to the front so the RNN iterates over it.
        x = x.permute(1, 0, 2, 3)
        y = y.permute(1, 0, 2, 3)
        return x, y

    def _get_x_y_in_correct_dims(self, x, y):
        """
        :param x: shape (seq_len, batch_size, num_sensor, input_dim)
        :param y: shape (horizon, batch_size, num_sensor, input_dim)
        :return: x: shape (seq_len, batch_size, num_sensor * input_dim)
                 y: shape (horizon, batch_size, num_sensor * output_dim)
        """
        batch_size = x.size(1)
        x = x.view(self.seq_len, batch_size, self.num_nodes * self.input_dim)
        # Keep only the first `output_dim` features of the targets before flattening.
        y = y[..., :self.output_dim].view(self.horizon, batch_size,
                                          self.num_nodes * self.output_dim)
        return x, y
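    # Shape walk-through of _get_x_y and _get_x_y_in_correct_dims with illustrative
    # METR-LA-style settings (batch_size=64, seq_len=12, horizon=12, num_nodes=207,
    # input_dim=2, output_dim=1):
    #   x: (64, 12, 207, 2) --permute--> (12, 64, 207, 2) --view--> (12, 64, 414)
    #   y: (64, 12, 207, 2) --permute--> (12, 64, 207, 2) --slice+view--> (12, 64, 207)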
    def _compute_loss(self, y_true, y_predicted, criterion):
        # Sum the per-step losses over the prediction horizon, computed in the
        # original data scale (predictions and targets are inverse-transformed first).
        loss = 0
        for t in range(self.horizon):
            loss += criterion(self.standard_scaler.inverse_transform(y_predicted[t]),
                              self.standard_scaler.inverse_transform(y_true[t]))
        return loss
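
if __name__ == '__main__':
    # Minimal driver sketch, assuming a YAML config with `data`, `model`, and `train`
    # sections and a pre-built adjacency matrix pickle. The config path is
    # hypothetical, and `utils.load_graph_data` is assumed to return
    # (sensor_ids, sensor_id_to_ind, adj_mx) as in the original DCRNN codebase.
    import yaml

    with open('data/model/dcrnn_config.yaml') as f:  # hypothetical path
        supervisor_config = yaml.safe_load(f)
    _, _, adj_mx = utils.load_graph_data(supervisor_config['data']['graph_pkl_filename'])
    supervisor = DCRNNSupervisor(adj_mx=adj_mx, **supervisor_config)
    supervisor.train()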