from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np import tensorflow as tf from tensorflow.contrib import legacy_seq2seq from lib.metrics import masked_mse_loss, masked_mae_loss, masked_rmse_loss from model.dcrnn_cell import DCGRUCell class DCRNNModel(object): def __init__(self, is_training, batch_size, scaler, adj_mx, **model_kwargs): # Scaler for data normalization. self._scaler = scaler # Train and loss self._loss = None self._mae = None self._train_op = None # Learning rate. self._lr = tf.get_variable('learning_rate', shape=(), initializer=tf.constant_initializer(0.01), trainable=False) self._new_lr = tf.placeholder(tf.float32, shape=(), name='new_learning_rate') self._lr_update = tf.assign(self._lr, self._new_lr, name='lr_update') max_diffusion_step = int(model_kwargs.get('max_diffusion_step', 2)) cl_decay_steps = int(model_kwargs.get('cl_decay_steps', 1000)) filter_type = model_kwargs.get('filter_type', 'laplacian') horizon = int(model_kwargs.get('horizon', 1)) max_grad_norm = float(model_kwargs.get('max_grad_norm', 5.0)) num_nodes = int(model_kwargs.get('num_nodes', 1)) num_rnn_layers = int(model_kwargs.get('num_rnn_layers', 1)) rnn_units = int(model_kwargs.get('rnn_units')) seq_len = int(model_kwargs.get('seq_len')) use_curriculum_learning = bool(model_kwargs.get('use_curriculum_learning', False)) input_dim = int(model_kwargs.get('input_dim', 1)) output_dim = int(model_kwargs.get('output_dim', 1)) aux_dim = input_dim - output_dim # Input (batch_size, timesteps, num_sensor, input_dim) self._inputs = tf.placeholder(tf.float32, shape=(batch_size, seq_len, num_nodes, input_dim), name='inputs') # Labels: (batch_size, timesteps, num_sensor, input_dim), same format with input except the temporal dimension. self._labels = tf.placeholder(tf.float32, shape=(batch_size, horizon, num_nodes, input_dim), name='labels') GO_SYMBOL = tf.zeros(shape=(batch_size, num_nodes * input_dim)) cell = DCGRUCell(rnn_units, adj_mx, max_diffusion_step=max_diffusion_step, num_nodes=num_nodes, filter_type=filter_type) cell_with_projection = DCGRUCell(rnn_units, adj_mx, max_diffusion_step=max_diffusion_step, num_nodes=num_nodes, num_proj=output_dim, filter_type=filter_type) encoding_cells = [cell] * num_rnn_layers decoding_cells = [cell] * (num_rnn_layers - 1) + [cell_with_projection] encoding_cells = tf.contrib.rnn.MultiRNNCell(encoding_cells, state_is_tuple=True) decoding_cells = tf.contrib.rnn.MultiRNNCell(decoding_cells, state_is_tuple=True) global_step = tf.train.get_or_create_global_step() # Outputs: (batch_size, timesteps, num_nodes, output_dim) with tf.variable_scope('DCRNN_SEQ'): inputs = tf.unstack(tf.reshape(self._inputs, (batch_size, seq_len, num_nodes * input_dim)), axis=1) labels = tf.unstack( tf.reshape(self._labels[..., :output_dim], (batch_size, horizon, num_nodes * output_dim)), axis=1) if aux_dim > 0: aux_info = tf.unstack(self._labels[..., output_dim:], axis=1) aux_info.insert(0, None) labels.insert(0, GO_SYMBOL) def _loop_function(prev, i): if is_training: # Return either the model's prediction or the previous ground truth in training. if use_curriculum_learning: c = tf.random_uniform((), minval=0, maxval=1.) threshold = self._compute_sampling_threshold(global_step, cl_decay_steps) result = tf.cond(tf.less(c, threshold), lambda: labels[i], lambda: prev) else: result = labels[i] else: # Return the prediction of the model in testing. result = prev if aux_dim > 0: result = tf.reshape(result, (batch_size, num_nodes, output_dim)) result = tf.concat([result, aux_info[i]], axis=-1) result = tf.reshape(result, (batch_size, num_nodes * input_dim)) return result _, enc_state = tf.contrib.rnn.static_rnn(encoding_cells, inputs, dtype=tf.float32) outputs, final_state = legacy_seq2seq.rnn_decoder(labels, enc_state, decoding_cells, loop_function=_loop_function) # Project the output to output_dim. outputs = tf.stack(outputs[:-1], axis=1) self._outputs = tf.reshape(outputs, (batch_size, horizon, num_nodes, output_dim), name='outputs') preds = self._outputs labels = self._labels[..., :output_dim] null_val = 0. self._mae = masked_mae_loss(self._scaler, null_val)(preds=preds, labels=labels) self._loss = masked_mae_loss(self._scaler, null_val)(preds=preds, labels=labels) if is_training: optimizer = tf.train.AdamOptimizer(self._lr) tvars = tf.trainable_variables() grads = tf.gradients(self._loss, tvars) grads, _ = tf.clip_by_global_norm(grads, max_grad_norm) self._train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step, name='train_op') self._merged = tf.summary.merge_all() @staticmethod def _compute_sampling_threshold(global_step, k): """ Computes the sampling probability for scheduled sampling using inverse sigmoid. :param global_step: :param k: :return: """ return tf.cast(k / (k + tf.exp(global_step / k)), tf.float32) @staticmethod def run_epoch_generator(sess, model, data_generator, return_output=False, train_op=None, writer=None): losses = [] maes = [] outputs = [] fetches = { 'mae': model.mae, 'loss': model.loss, 'global_step': tf.train.get_or_create_global_step() } if train_op: fetches.update({ 'train_op': train_op, }) merged = model.merged if merged is not None: fetches.update({'merged': merged}) if return_output: fetches.update({ 'outputs': model.outputs }) for _, (x, y) in enumerate(data_generator): feed_dict = { model.inputs: x, model.labels: y, } vals = sess.run(fetches, feed_dict=feed_dict) losses.append(vals['loss']) maes.append(vals['mae']) if writer is not None and 'merged' in vals: writer.add_summary(vals['merged'], global_step=vals['global_step']) if return_output: outputs.append(vals['outputs']) results = { 'loss': np.mean(losses), 'mae': np.mean(maes) } if return_output: results['outputs'] = outputs return results def get_lr(self, sess): return np.asscalar(sess.run(self._lr)) def set_lr(self, sess, lr): sess.run(self._lr_update, feed_dict={ self._new_lr: lr }) @property def inputs(self): return self._inputs @property def labels(self): return self._labels @property def loss(self): return self._loss @property def lr(self): return self._lr @property def mae(self): return self._mae @property def merged(self): return self._merged @property def outputs(self): return self._outputs @property def train_op(self): return self._train_op