208 lines
7.9 KiB
Python
208 lines
7.9 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
from tensorflow.contrib import legacy_seq2seq
|
|
|
|
from lib.metrics import masked_mse_loss, masked_mae_loss, masked_rmse_loss
|
|
from model.dcrnn_cell import DCGRUCell
|
|
|
|
|
|
class DCRNNModel(object):
    """Sequence-to-sequence DCRNN model built as a TensorFlow 1.x graph.

    An encoder stack of ``DCGRUCell`` layers consumes ``seq_len`` input steps. A decoder
    stack, whose last layer projects to ``output_dim``, then emits ``horizon`` predicted
    steps. Only instances built with ``is_training`` true:

    - create an optimizer / ``train_op``;
    - may feed ground-truth labels back into the decoder.
    """

    def __init__(self, is_training, batch_size, scaler, adj_mx, **model_kwargs):
        """Build the model graph.

        :param is_training: if true, build the training op and let the decoder consume
            ground-truth labels (always, or via scheduled sampling when
            ``use_curriculum_learning`` is set).
        :param batch_size: static batch dimension of the input/label placeholders.
        :param scaler: passed to ``masked_mae_loss`` (presumably to undo data
            normalization before computing the error -- confirm in lib.metrics).
        :param adj_mx: adjacency matrix handed to every ``DCGRUCell``.
        :param model_kwargs: hyper-parameters. ``rnn_units`` and ``seq_len`` have no
            default (``int(None)`` raises ``TypeError`` if they are missing); the other
            keys fall back to the defaults below.
        """
        # Scaler for data normalization.
        self._scaler = scaler

        # Train and loss; the train op is only built when is_training is true.
        self._loss = None
        self._mae = None
        self._train_op = None

        # Learning rate: a non-trainable variable, changed through the placeholder /
        # assign op pair below (see get_lr / set_lr).
        self._lr = tf.get_variable('learning_rate', shape=(), initializer=tf.constant_initializer(0.01),
                                   trainable=False)
        self._new_lr = tf.placeholder(tf.float32, shape=(), name='new_learning_rate')
        self._lr_update = tf.assign(self._lr, self._new_lr, name='lr_update')

        max_diffusion_step = int(model_kwargs.get('max_diffusion_step', 2))
        cl_decay_steps = int(model_kwargs.get('cl_decay_steps', 1000))
        filter_type = model_kwargs.get('filter_type', 'laplacian')
        horizon = int(model_kwargs.get('horizon', 1))
        max_grad_norm = float(model_kwargs.get('max_grad_norm', 5.0))
        num_nodes = int(model_kwargs.get('num_nodes', 1))
        num_rnn_layers = int(model_kwargs.get('num_rnn_layers', 1))
        rnn_units = int(model_kwargs.get('rnn_units'))
        seq_len = int(model_kwargs.get('seq_len'))
        use_curriculum_learning = bool(model_kwargs.get('use_curriculum_learning', False))
        input_dim = int(model_kwargs.get('input_dim', 1))
        output_dim = int(model_kwargs.get('output_dim', 1))
        # Number of extra (non-predicted) features per node. When positive, they are
        # re-attached to every decoder input.
        aux_dim = input_dim - output_dim

        # Inputs: (batch_size, seq_len, num_nodes, input_dim)
        self._inputs = tf.placeholder(tf.float32, shape=(batch_size, seq_len, num_nodes, input_dim), name='inputs')
        # Labels: (batch_size, horizon, num_nodes, input_dim), same layout as the inputs
        # except for the temporal dimension.
        self._labels = tf.placeholder(tf.float32, shape=(batch_size, horizon, num_nodes, input_dim), name='labels')

        # All-zero frame used as the first decoder input.
        go_symbol = tf.zeros(shape=(batch_size, num_nodes * input_dim))

        cell = DCGRUCell(rnn_units, adj_mx, max_diffusion_step=max_diffusion_step, num_nodes=num_nodes,
                         filter_type=filter_type)
        cell_with_projection = DCGRUCell(rnn_units, adj_mx, max_diffusion_step=max_diffusion_step, num_nodes=num_nodes,
                                         num_proj=output_dim, filter_type=filter_type)
        # NOTE(review): the same `cell` object is repeated for every layer. Each layer
        # only gets its own weights if DCGRUCell creates its variables under the
        # per-call scope chosen by MultiRNNCell -- confirm in model/dcrnn_cell.py.
        encoding_cells = tf.contrib.rnn.MultiRNNCell([cell] * num_rnn_layers, state_is_tuple=True)
        decoding_cells = tf.contrib.rnn.MultiRNNCell([cell] * (num_rnn_layers - 1) + [cell_with_projection],
                                                     state_is_tuple=True)

        global_step = tf.train.get_or_create_global_step()
        # Outputs: (batch_size, horizon, num_nodes, output_dim)
        with tf.variable_scope('DCRNN_SEQ'):
            # Flatten nodes x features and split along time:
            # seq_len tensors of shape (batch_size, num_nodes * input_dim).
            inputs = tf.unstack(tf.reshape(self._inputs, (batch_size, seq_len, num_nodes * input_dim)), axis=1)
            # Ground-truth targets, shifted right by one step behind the GO frame:
            # decoder_inputs[i] holds the label of step i - 1 (horizon + 1 entries in total).
            decoder_inputs = [go_symbol] + tf.unstack(
                tf.reshape(self._labels[..., :output_dim], (batch_size, horizon, num_nodes * output_dim)), axis=1)
            if aux_dim > 0:
                # aux_info[i] has shape (batch_size, num_nodes, aux_dim) and is aligned
                # with decoder_inputs[i]; index 0 (the GO step) has no auxiliary data.
                aux_info = [None] + tf.unstack(self._labels[..., output_dim:], axis=1)

            def _loop_function(prev, i):
                """Select the decoder input for step ``i`` given the previous output ``prev``."""
                if is_training:
                    if use_curriculum_learning:
                        # Scheduled sampling: use ground truth with probability `threshold`,
                        # otherwise the model's own previous prediction.
                        c = tf.random_uniform((), minval=0, maxval=1.)
                        threshold = self._compute_sampling_threshold(global_step, cl_decay_steps)
                        result = tf.cond(tf.less(c, threshold), lambda: decoder_inputs[i], lambda: prev)
                    else:
                        # Teacher forcing: always feed the previous ground truth.
                        result = decoder_inputs[i]
                else:
                    # Inference: feed back the model's own prediction.
                    result = prev
                if aux_dim > 0:
                    # Re-attach the auxiliary features so the decoder input again has
                    # num_nodes * input_dim values per sample.
                    result = tf.reshape(result, (batch_size, num_nodes, output_dim))
                    result = tf.concat([result, aux_info[i]], axis=-1)
                    result = tf.reshape(result, (batch_size, num_nodes * input_dim))
                return result

            _, enc_state = tf.contrib.rnn.static_rnn(encoding_cells, inputs, dtype=tf.float32)
            outputs, _ = legacy_seq2seq.rnn_decoder(decoder_inputs, enc_state, decoding_cells,
                                                    loop_function=_loop_function)

        # decoder_inputs has horizon + 1 entries. Drop the last output so exactly
        # `horizon` predictions remain, then restore the per-node layout.
        outputs = tf.stack(outputs[:-1], axis=1)
        self._outputs = tf.reshape(outputs, (batch_size, horizon, num_nodes, output_dim), name='outputs')

        preds = self._outputs
        truth = self._labels[..., :output_dim]

        null_val = 0.
        # The training loss is the masked MAE itself. Build it once and share the tensor
        # instead of constructing an identical sub-graph twice.
        self._mae = masked_mae_loss(self._scaler, null_val)(preds=preds, labels=truth)
        self._loss = self._mae
        if is_training:
            optimizer = tf.train.AdamOptimizer(self._lr)
            tvars = tf.trainable_variables()
            grads = tf.gradients(self._loss, tvars)
            grads, _ = tf.clip_by_global_norm(grads, max_grad_norm)
            self._train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=global_step, name='train_op')

        self._merged = tf.summary.merge_all()

    @staticmethod
    def _compute_sampling_threshold(global_step, k):
        """Compute the scheduled-sampling probability using inverse sigmoid decay.

        The value is ``k / (k + exp(global_step / k))``. It equals ``k / (k + 1)`` at
        step 0 and decreases towards 0 as ``global_step`` grows.

        :param global_step: the global step tensor.
        :param k: decay constant (``cl_decay_steps``); larger values decay more slowly.
        :return: scalar float32 tensor, the probability of feeding ground truth to the
            decoder.
        """
        return tf.cast(k / (k + tf.exp(global_step / k)), tf.float32)

    @staticmethod
    def run_epoch_generator(sess, model, data_generator, return_output=False, train_op=None, writer=None):
        """Run ``model`` once over every ``(x, y)`` batch yielded by ``data_generator``.

        :param sess: session used to evaluate the fetches.
        :param model: model whose ``inputs`` / ``labels`` placeholders are fed.
        :param data_generator: iterable of ``(inputs, labels)`` batches.
        :param return_output: if true, also collect ``model.outputs`` for every batch.
        :param train_op: optional op run together with the metrics (e.g. ``model.train_op``).
        :param writer: optional object with ``add_summary``; receives ``model.merged``
            (when not None) at the current global step.
        :return: dict with the mean ``'loss'`` and ``'mae'`` over all batches and, if
            ``return_output`` is true, the list of per-batch ``'outputs'``. With no
            batches the means are ``nan`` (``np.mean`` of an empty list).
        """
        fetches = {
            'mae': model.mae,
            'loss': model.loss,
            'global_step': tf.train.get_or_create_global_step(),
        }
        if train_op is not None:
            fetches['train_op'] = train_op
        merged = model.merged
        if merged is not None:
            fetches['merged'] = merged
        if return_output:
            fetches['outputs'] = model.outputs

        losses = []
        maes = []
        outputs = []
        for x, y in data_generator:
            vals = sess.run(fetches, feed_dict={model.inputs: x, model.labels: y})
            losses.append(vals['loss'])
            maes.append(vals['mae'])
            if writer is not None and 'merged' in vals:
                writer.add_summary(vals['merged'], global_step=vals['global_step'])
            if return_output:
                outputs.append(vals['outputs'])

        results = {
            'loss': np.mean(losses),
            'mae': np.mean(maes),
        }
        if return_output:
            results['outputs'] = outputs
        return results

    def get_lr(self, sess):
        """Return the current learning rate as a Python float."""
        # np.asscalar was removed in NumPy 1.23; .item() is the supported equivalent.
        return np.asarray(sess.run(self._lr)).item()

    def set_lr(self, sess, lr):
        """Assign ``lr`` to the learning-rate variable via the update op."""
        sess.run(self._lr_update, feed_dict={
            self._new_lr: lr
        })

    @property
    def inputs(self):
        return self._inputs

    @property
    def labels(self):
        return self._labels

    @property
    def loss(self):
        return self._loss

    @property
    def lr(self):
        return self._lr

    @property
    def mae(self):
        return self._mae

    @property
    def merged(self):
        return self._merged

    @property
    def outputs(self):
        return self._outputs

    @property
    def train_op(self):
        return self._train_op
|