DCRNN/model/dcrnn_supervisor.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import numpy as np
import pandas as pd
import tensorflow as tf

from lib.utils import generate_graph_seq2seq_io_data_with_time
from model.dcrnn_model import DCRNNModel
from model.tf_model_supervisor import TFModelSupervisor


class DCRNNSupervisor(TFModelSupervisor):
    """
    Runs experiments using the Graph Random Walk RNN (DCRNN) model.
    """

    def __init__(self, traffic_reading_df, adj_mx, config):
        self._adj_mx = adj_mx
        super(DCRNNSupervisor, self).__init__(config, df_data=traffic_reading_df)

    def _prepare_train_val_test_data(self):
        """Generates seq2seq input/output tensors for the train/val/test splits."""
        # Parse model parameters.
        batch_size = self._get_config('batch_size')
        horizon = self._get_config('horizon')
        seq_len = self._get_config('seq_len')
        test_batch_size = 1
        add_time_in_day = self._get_config('add_time_in_day')
        num_nodes = self._df_train.shape[-1]

        x_train, y_train = generate_graph_seq2seq_io_data_with_time(self._df_train,
                                                                    batch_size=batch_size,
                                                                    seq_len=seq_len,
                                                                    horizon=horizon,
                                                                    num_nodes=num_nodes,
                                                                    scaler=self._scaler,
                                                                    add_time_in_day=add_time_in_day,
                                                                    add_day_in_week=False)
        x_val, y_val = generate_graph_seq2seq_io_data_with_time(self._df_val,
                                                                batch_size=batch_size,
                                                                seq_len=seq_len,
                                                                horizon=horizon,
                                                                num_nodes=num_nodes,
                                                                scaler=self._scaler,
                                                                add_time_in_day=add_time_in_day,
                                                                add_day_in_week=False)
        x_test, y_test = generate_graph_seq2seq_io_data_with_time(self._df_test,
                                                                  batch_size=test_batch_size,
                                                                  seq_len=seq_len,
                                                                  horizon=horizon,
                                                                  num_nodes=num_nodes,
                                                                  scaler=self._scaler,
                                                                  add_time_in_day=add_time_in_day,
                                                                  add_day_in_week=False)
        return x_train, y_train, x_val, y_val, x_test, y_test

    def _build_train_val_test_models(self):
        """Builds train/val/test models that share weights through the 'DCRNN' variable scope."""
        input_dim = self._x_train.shape[-1]
        num_nodes = self._df_test.shape[-1]
        output_dim = self._y_train.shape[-1]
        test_batch_size = self._get_config('test_batch_size')

        train_config = dict(self._config)
        train_config.update({
            'input_dim': input_dim,
            'num_nodes': num_nodes,
            'output_dim': output_dim,
        })
        test_config = dict(self._config)
        test_config.update({
            'batch_size': test_batch_size,
            'input_dim': input_dim,
            'num_nodes': num_nodes,
            'output_dim': output_dim,
        })

        with tf.name_scope('Train'):
            with tf.variable_scope('DCRNN', reuse=False):
                train_model = DCRNNModel(is_training=True, config=train_config, scaler=self._scaler,
                                         adj_mx=self._adj_mx)
        with tf.name_scope('Val'):
            with tf.variable_scope('DCRNN', reuse=True):
                val_model = DCRNNModel(is_training=False, config=train_config, scaler=self._scaler,
                                       adj_mx=self._adj_mx)
        with tf.name_scope('Test'):
            with tf.variable_scope('DCRNN', reuse=True):
                test_model = DCRNNModel(is_training=False, config=test_config, scaler=self._scaler,
                                        adj_mx=self._adj_mx)
        return train_model, val_model, test_model

    def _convert_model_outputs_to_eval_df(self, y_preds):
        """Converts model outputs back to DataFrames in the original scale, one per horizon step."""
        y_preds = np.stack(y_preds, axis=1)
        # y_preds: (batch_size, epoch_size, horizon, num_nodes, output_dim)
        horizon = self._get_config('horizon')
        num_nodes = self._df_train.shape[-1]
        df_preds = {}
        for horizon_i in range(horizon):
            y_pred = np.reshape(y_preds[:, :, horizon_i, :, 0], self._eval_dfs[horizon_i].shape)
            df_pred = pd.DataFrame(self._scaler.inverse_transform(y_pred),
                                   index=self._eval_dfs[horizon_i].index,
                                   columns=self._eval_dfs[horizon_i].columns)
            df_preds[horizon_i] = df_pred
        return df_preds

    @staticmethod
    def _generate_run_id(config):
        """Generates a descriptive run id from the key hyperparameters and a timestamp."""
        batch_size = config.get('batch_size')
        dropout = config.get('dropout')
        learning_rate = config.get('learning_rate')
        loss_func = config.get('loss_func')
        max_diffusion_step = config['max_diffusion_step']
        num_rnn_layers = config.get('num_rnn_layers')
        rnn_units = config.get('rnn_units')
        seq_len = config.get('seq_len')
        structure = '-'.join(['%d' % rnn_units for _ in range(num_rnn_layers)])
        horizon = config.get('horizon')
        filter_type = config.get('filter_type')
        filter_type_abbr = 'L'
        if filter_type == 'random_walk':
            filter_type_abbr = 'R'
        elif filter_type == 'dual_random_walk':
            filter_type_abbr = 'DR'
        run_id = 'dcrnn_%s_%d_h_%d_%s_lr_%g_bs_%d_d_%.2f_sl_%d_%s_%s/' % (
            filter_type_abbr, max_diffusion_step, horizon,
            structure, learning_rate, batch_size,
            dropout, seq_len, loss_func,
            time.strftime('%m%d%H%M%S'))
        return run_id
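

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream file): a minimal,
# hypothetical driver showing how DCRNNSupervisor could be instantiated.
# The data paths, the assumed pickle layout of the adjacency-matrix file, and
# the config values below are placeholders inferred from the keys this module
# reads; the parent TFModelSupervisor may require additional keys, and the
# real entry point lives in the repository's training script.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import pickle

    # Hypothetical artifact paths -- replace with real data files.
    traffic_reading_df = pd.read_hdf('data/metr-la.h5')
    with open('data/adj_mx.pkl', 'rb') as f:
        # Assumed pickle layout: (sensor_ids, sensor_id_to_ind, adj_mx).
        _, _, adj_mx = pickle.load(f)

    config = {
        # Keys read by this module; the values are illustrative only.
        'batch_size': 64,
        'test_batch_size': 1,
        'seq_len': 12,
        'horizon': 12,
        'add_time_in_day': True,
        'max_diffusion_step': 2,
        'num_rnn_layers': 2,
        'rnn_units': 64,
        'dropout': 0.0,
        'learning_rate': 0.01,
        'loss_func': 'MAE',
        'filter_type': 'dual_random_walk',
    }

    supervisor = DCRNNSupervisor(traffic_reading_df=traffic_reading_df,
                                 adj_mx=adj_mx, config=config)
    # Training/evaluation is driven by methods defined on TFModelSupervisor
    # (not shown in this file).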