DCRNN/model/tf_model.py

131 lines
3.4 KiB
Python

"""
Base class for tensorflow models for traffic forecasting.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
class TFModel(object):
    """
    Base class holding the graph handles shared by the forecasting models.

    The constructor only creates the learning-rate variable and the ops needed
    to update it; every other handle (inputs, labels, outputs, loss, mae,
    train op, merged summary) is initialised to ``None`` and exposed through a
    read-only property.
    """

    def __init__(self, config, scaler=None, **kwargs):
        """
        Initialization including placeholders, learning rate,
        :param config: mapping of hyper-parameters (anything ``dict()`` accepts);
            only ``learning_rate`` is read here (default 0.001).
        :param scaler: data z-norm normalizer
        :param kwargs: accepted for interface compatibility; not used here.
        """
        self._config = dict(config)

        # Placeholders for input and output.
        self._inputs = None
        self._labels = None
        self._outputs = None

        # Scaler for data normalization.
        self._scaler = scaler

        # Train and loss
        self._loss = None
        self._mae = None
        self._train_op = None

        # Learning rate, stored as a non-trainable variable so that it can be
        # changed at run time through the assign op below (see ``set_lr``).
        # Read from the copied dict so that any input accepted by ``dict()``
        # works, not only objects exposing ``.get``.
        learning_rate = self._config.get('learning_rate', 0.001)
        self._lr = tf.get_variable('learning_rate', shape=(),
                                   initializer=tf.constant_initializer(learning_rate),
                                   trainable=False)
        self._new_lr = tf.placeholder(tf.float32, shape=(), name='new_learning_rate')
        self._lr_update = tf.assign(self._lr, self._new_lr, name='lr_update')

        # Log merged summary
        self._merged = None

    @staticmethod
    def run_epoch(sess, model, inputs, labels, return_output=False, train_op=None, writer=None):
        """
        Runs the model once over every (input, label) batch pair.
        :param sess: session used to evaluate the fetches.
        :param model: object exposing ``mae``, ``loss``, ``merged``, ``outputs``,
            ``inputs`` and ``labels``.
        :param inputs: iterable of input batches.
        :param labels: iterable of label batches, consumed in lockstep with
            ``inputs`` (iteration stops at the shorter of the two).
        :param return_output: if True, also fetch and return the model outputs.
        :param train_op: optional op fetched together with the metrics.
        :param writer: optional summary writer; used only when the model
            provides a merged summary.
        :return: dict with the mean ``loss`` and mean ``mae`` over all batches,
            plus ``outputs`` (a list with one entry per batch) when
            ``return_output`` is True.
        """
        losses = []
        maes = []
        outputs = []

        fetches = {
            'mae': model.mae,
            'loss': model.loss,
            'global_step': tf.train.get_or_create_global_step()
        }
        # Compare against None explicitly: truth-testing a graph tensor raises
        # TypeError, so ``if train_op:`` is only safe for plain operations.
        if train_op is not None:
            fetches['train_op'] = train_op
        merged = model.merged
        if merged is not None:
            fetches['merged'] = merged
        if return_output:
            fetches['outputs'] = model.outputs

        for x, y in zip(inputs, labels):
            feed_dict = {
                model.inputs: x,
                model.labels: y,
            }
            vals = sess.run(fetches, feed_dict=feed_dict)

            losses.append(vals['loss'])
            maes.append(vals['mae'])
            if writer is not None and 'merged' in vals:
                writer.add_summary(vals['merged'], global_step=vals['global_step'])
            if return_output:
                outputs.append(vals['outputs'])

        results = {
            'loss': np.mean(losses),
            'mae': np.mean(maes)
        }
        if return_output:
            results['outputs'] = outputs
        return results

    def get_lr(self, sess):
        """
        Evaluates the learning-rate variable.
        :param sess: session used to evaluate the variable.
        :return: the current learning rate as a Python scalar.
        """
        # ``np.asscalar`` was removed from NumPy; ``ndarray.item()`` is the
        # equivalent conversion to a Python scalar.
        return np.asarray(sess.run(self._lr)).item()

    def set_lr(self, sess, lr):
        """
        Assigns a new value to the learning-rate variable.
        :param sess: session used to run the assign op.
        :param lr: new learning rate, fed through the placeholder.
        """
        sess.run(self._lr_update, feed_dict={
            self._new_lr: lr
        })

    @property
    def inputs(self):
        """Handle stored in ``_inputs`` (``None`` until something sets it)."""
        return self._inputs

    @property
    def labels(self):
        """Handle stored in ``_labels`` (``None`` until something sets it)."""
        return self._labels

    @property
    def loss(self):
        """Handle stored in ``_loss`` (``None`` until something sets it)."""
        return self._loss

    @property
    def lr(self):
        """The learning-rate variable created in ``__init__``."""
        return self._lr

    @property
    def mae(self):
        """Handle stored in ``_mae`` (``None`` until something sets it)."""
        return self._mae

    @property
    def merged(self):
        """Handle stored in ``_merged`` (``None`` until something sets it)."""
        return self._merged

    @property
    def outputs(self):
        """Handle stored in ``_outputs`` (``None`` until something sets it)."""
        return self._outputs

    @property
    def train_op(self):
        """Handle stored in ``_train_op`` (``None`` until something sets it)."""
        return self._train_op