From c876cbfba38b54f994709cc9019c38633e76ac4b Mon Sep 17 00:00:00 2001
From: Chintan Shah
Date: Wed, 2 Oct 2019 17:34:07 -0400
Subject: [PATCH] Setup training loop and logging

---
 model/pytorch/dcrnn_supervisor.py | 98 +++++++++++++++++++++++++------
 1 file changed, 81 insertions(+), 17 deletions(-)

diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py
index c10cb91..0194a8f 100644
--- a/model/pytorch/dcrnn_supervisor.py
+++ b/model/pytorch/dcrnn_supervisor.py
@@ -1,37 +1,78 @@
+import os
+import time
+
 import numpy as np
 import torch
 
+from lib import utils
 from model.pytorch.dcrnn_model import EncoderModel, DecoderModel
 
 
 class DCRNNSupervisor:
-    def __init__(self, adj_mx, **kwargs):
+    def __init__(self, adj_mx, encoder_model: EncoderModel, decoder_model: DecoderModel, **kwargs):
+        self.decoder_model = decoder_model
+        self.encoder_model = encoder_model
         self._kwargs = kwargs
         self._data_kwargs = kwargs.get('data')
         self._model_kwargs = kwargs.get('model')
         self._train_kwargs = kwargs.get('train')
 
+        # logging.
+        self._log_dir = self._get_log_dir(kwargs)
+        log_level = self._kwargs.get('log_level', 'INFO')
+        self._logger = utils.get_logger(self._log_dir, __name__, 'info.log', level=log_level)
+
+        # data set
+        self._data = utils.load_dataset(**self._data_kwargs)
+        self.standard_scaler = self._data['scaler']
+
         self.num_nodes = int(self._model_kwargs.get('num_nodes', 1))
         self.input_dim = int(self._model_kwargs.get('input_dim', 1))
         self.seq_len = int(self._model_kwargs.get('seq_len'))  # for the encoder
         self.output_dim = int(self._model_kwargs.get('output_dim', 1))
+        self.cl_decay_steps = int(self._model_kwargs.get('cl_decay_steps', 1000))
         self.use_curriculum_learning = bool(
             self._model_kwargs.get('use_curriculum_learning', False))
         self.horizon = int(self._model_kwargs.get('horizon', 1))  # for the decoder
 
-    def train(self, encoder_model: EncoderModel, decoder_model: DecoderModel, **kwargs):
+    @staticmethod
+    def _get_log_dir(kwargs):
+        log_dir = kwargs['train'].get('log_dir')
+        if log_dir is None:
+            batch_size = kwargs['data'].get('batch_size')
+            learning_rate = kwargs['train'].get('base_lr')
+            max_diffusion_step = kwargs['model'].get('max_diffusion_step')
+            num_rnn_layers = kwargs['model'].get('num_rnn_layers')
+            rnn_units = kwargs['model'].get('rnn_units')
+            structure = '-'.join(
+                ['%d' % rnn_units for _ in range(num_rnn_layers)])
+            horizon = kwargs['model'].get('horizon')
+            filter_type = kwargs['model'].get('filter_type')
+            filter_type_abbr = 'L'
+            if filter_type == 'random_walk':
+                filter_type_abbr = 'R'
+            elif filter_type == 'dual_random_walk':
+                filter_type_abbr = 'DR'
+            run_id = 'dcrnn_%s_%d_h_%d_%s_lr_%g_bs_%d_%s/' % (
+                filter_type_abbr, max_diffusion_step, horizon,
+                structure, learning_rate, batch_size,
+                time.strftime('%m%d%H%M%S'))
+            base_dir = kwargs.get('base_dir')
+            log_dir = os.path.join(base_dir, run_id)
+        if not os.path.exists(log_dir):
+            os.makedirs(log_dir)
+        return log_dir
+
+    def train(self, **kwargs):
         kwargs.update(self._train_kwargs)
         return self._train(**kwargs)
 
-    def _train_one_batch(self, inputs, labels, encoder_model: EncoderModel,
-                         decoder_model: DecoderModel, encoder_optimizer,
+    def _train_one_batch(self, inputs, labels, batches_seen, encoder_optimizer,
                          decoder_optimizer, criterion):
         """
         :param inputs: shape (seq_len, batch_size, num_sensor, input_dim)
         :param labels: shape (horizon, batch_size, num_sensor, input_dim)
-        :param encoder_model:
-        :param decoder_model:
         :param encoder_optimizer:
         :param decoder_optimizer:
         :param criterion:
             minimize this criterion
@@ -50,7 +91,7 @@ class DCRNNSupervisor:
 
         encoder_hidden_state = None
         for t in range(self.seq_len):
-            _, encoder_hidden_state = encoder_model.forward(inputs[t], encoder_hidden_state)
+            _, encoder_hidden_state = self.encoder_model.forward(inputs[t], encoder_hidden_state)
 
         go_symbol = torch.zeros((batch_size, self.num_nodes * self.output_dim))
 
@@ -58,27 +99,50 @@ class DCRNNSupervisor:
         decoder_input = go_symbol
 
         for t in range(self.horizon):
-            decoder_output, decoder_hidden_state = decoder_model.forward(decoder_input,
-                                                                         decoder_hidden_state)
+            decoder_output, decoder_hidden_state = self.decoder_model.forward(decoder_input,
+                                                                              decoder_hidden_state)
             decoder_input = decoder_output
             if self.use_curriculum_learning:
                 # todo check for is_training (pytorch way?)
                 c = np.random.uniform(0, 1)
-                if c < self._compute_sampling_threshold():
+                if c < self._compute_sampling_threshold(batches_seen):
                     decoder_input = labels[t]
-            loss += criterion(decoder_output, labels[t])
+            loss += criterion(self.standard_scaler.inverse_transform(decoder_output),
+                              self.standard_scaler.inverse_transform(labels[t]))
 
         loss.backward()
         encoder_optimizer.step()
        decoder_optimizer.step()
 
         return loss.item()
 
-    def _train(self, encoder_model: EncoderModel, decoder_model: DecoderModel, base_lr, epoch,
+    def _train(self, base_lr, steps,
                patience=50, epochs=100,
-               min_learning_rate=2e-6, lr_decay_ratio=0.1, save_model=1,
-               test_every_n_epochs=10):
-        pass
+               min_learning_rate=2e-6, lr_decay_ratio=0.1, log_every=10, save_model=1,
+               test_every_n_epochs=10, **kwargs):
+        # steps is used in the learning rate schedule - TODO: decide whether it is needed here
+        encoder_optimizer = torch.optim.Adam(self.encoder_model.parameters(), lr=base_lr)
+        decoder_optimizer = torch.optim.Adam(self.decoder_model.parameters(), lr=base_lr)
+        criterion = torch.nn.L1Loss()  # mae loss
+
+        batches_seen = 0
+        self._logger.info('Start training ...')
+        for epoch_num in range(epochs):
+            train_iterator = self._data['train_loader'].get_iterator()
+            losses = []
+
+            start_time = time.time()
+
+            for x, y in train_iterator:
+                loss = self._train_one_batch(x, y, batches_seen, encoder_optimizer, decoder_optimizer, criterion)
+                losses.append(loss)
+                batches_seen += 1
+
+            end_time = time.time()
+            if epoch_num % log_every == 0:
+                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f} lr:{:.6f} {:.1f}s'.format(
+                    epoch_num, epochs, batches_seen, np.mean(losses), 0.0, 0.0, (end_time - start_time))
+                self._logger.info(message)
+
+    def _compute_sampling_threshold(self, batches_seen):
+        return self.cl_decay_steps / (self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))
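
Usage sketch (illustrative, not taken from the diff): with this change the supervisor is constructed with the encoder and decoder modules plus a nested config dict ('data', 'model', 'train'), and train() forwards the 'train' section into _train(). A minimal wiring might look like the sketch below; the config values, file paths, the data-loading keys, and the EncoderModel/DecoderModel constructor calls are assumptions for illustration only.

    import numpy as np

    from model.pytorch.dcrnn_model import EncoderModel, DecoderModel
    from model.pytorch.dcrnn_supervisor import DCRNNSupervisor

    # Hypothetical config; the real project loads something equivalent from a YAML file.
    config = {
        'base_dir': 'runs',            # parent folder for the generated log_dir
        'log_level': 'INFO',
        'data': {'dataset_dir': 'data/METR-LA', 'batch_size': 64},  # assumed keys for utils.load_dataset
        'model': {
            'num_nodes': 207, 'input_dim': 2, 'output_dim': 1,
            'seq_len': 12, 'horizon': 12,
            'max_diffusion_step': 2, 'num_rnn_layers': 2, 'rnn_units': 64,
            'filter_type': 'dual_random_walk',
            'cl_decay_steps': 2000, 'use_curriculum_learning': True,
        },
        # 'steps' is accepted by _train but not yet used in this patch;
        # assumed here to be LR-decay milestone epochs.
        'train': {'base_lr': 0.01, 'steps': [20, 30, 40], 'epochs': 100, 'log_every': 1},
    }

    adj_mx = np.load('data/adj_mx.npy')  # placeholder for the real graph adjacency matrix

    # Assumed constructor signature for the encoder/decoder modules.
    encoder = EncoderModel(adj_mx, **config['model'])
    decoder = DecoderModel(adj_mx, **config['model'])

    supervisor = DCRNNSupervisor(adj_mx, encoder, decoder, **config)
    supervisor.train()  # the 'train' kwargs are merged and forwarded to _train()

With these illustrative settings, _get_log_dir() would create a run directory such as runs/dcrnn_DR_2_h_12_64-64_lr_0.01_bs_64_<timestamp>/. The new _compute_sampling_threshold() is the inverse-sigmoid decay used for scheduled sampling: it starts near 1 (the decoder is almost always fed the ground truth) and decays toward 0 as batches_seen grows, at a rate set by cl_decay_steps.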