""" Base class for tensorflow models for traffic forecasting. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np import tensorflow as tf class TFModel(object): def __init__(self, config, scaler=None, **kwargs): """ Initialization including placeholders, learning rate, :param config: :param scaler: data z-norm normalizer :param kwargs: """ self._config = dict(config) # Placeholders for input and output. self._inputs = None self._labels = None self._outputs = None # Scaler for data normalization. self._scaler = scaler # Train and loss self._loss = None self._mae = None self._train_op = None # Learning rate. learning_rate = config.get('learning_rate', 0.001) self._lr = tf.get_variable('learning_rate', shape=(), initializer=tf.constant_initializer(learning_rate), trainable=False) self._new_lr = tf.placeholder(tf.float32, shape=(), name='new_learning_rate') self._lr_update = tf.assign(self._lr, self._new_lr, name='lr_update') # Log merged summary self._merged = None @staticmethod def run_epoch(sess, model, inputs, labels, return_output=False, train_op=None, writer=None): losses = [] maes = [] outputs = [] fetches = { 'mae': model.mae, 'loss': model.loss, 'global_step': tf.train.get_or_create_global_step() } if train_op: fetches.update({ 'train_op': train_op, }) merged = model.merged if merged is not None: fetches.update({'merged': merged}) if return_output: fetches.update({ 'outputs': model.outputs }) for _, (x, y) in enumerate(zip(inputs, labels)): feed_dict = { model.inputs: x, model.labels: y, } vals = sess.run(fetches, feed_dict=feed_dict) losses.append(vals['loss']) maes.append(vals['mae']) if writer is not None and 'merged' in vals: writer.add_summary(vals['merged'], global_step=vals['global_step']) if return_output: outputs.append(vals['outputs']) results = { 'loss': np.mean(losses), 'mae': np.mean(maes) } if return_output: results['outputs'] = outputs return results def get_lr(self, sess): return np.asscalar(sess.run(self._lr)) def set_lr(self, sess, lr): sess.run(self._lr_update, feed_dict={ self._new_lr: lr }) @property def inputs(self): return self._inputs @property def labels(self): return self._labels @property def loss(self): return self._loss @property def lr(self): return self._lr @property def mae(self): return self._mae @property def merged(self): return self._merged @property def outputs(self): return self._outputs @property def train_op(self): return self._train_op