DCRNN/dcrnn_train.py

88 lines
4.4 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import pandas as pd
import tensorflow as tf
from lib import log_helper
from lib.dcrnn_utils import load_graph_data
from model.dcrnn_supervisor import DCRNNSupervisor
# Command-line flags.
#
# Most numeric flags default to -1 and several string/bool flags default to None.
# main() copies a flag into the supervisor configuration only when it has been
# set to something else (numeric flags: value >= 0; string flags: non-empty;
# use_curriculum_learning: not None), so the defaults act as "not supplied".
flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_integer('batch_size', -1, 'Batch size')
flags.DEFINE_integer(
    'cl_decay_steps',
    -1,
    'Parameter to control the decay speed of probability of feeding groundth instead of model output.')
flags.DEFINE_string(
    'config_filename',
    None,
    'Configuration filename for restoring the model.')
flags.DEFINE_integer('epochs', -1, 'Maximum number of epochs to train.')
flags.DEFINE_string('filter_type', None, 'laplacian/random_walk/dual_random_walk.')
flags.DEFINE_string(
    'graph_pkl_filename',
    'data/sensor_graph/adj_mx.pkl',
    'Pickle file containing: sensor_ids, sensor_id_to_ind_map, dist_matrix')
flags.DEFINE_integer('horizon', -1, 'Maximum number of timestamps to prediction.')
flags.DEFINE_float('l1_decay', -1.0, 'L1 Regularization')
flags.DEFINE_float('lr_decay', -1.0, 'Learning rate decay.')
flags.DEFINE_integer(
    'lr_decay_epoch',
    -1,
    'The epoch that starting decaying the parameter.')
flags.DEFINE_integer('lr_decay_interval', -1, 'Interval beteween each deacy.')
flags.DEFINE_float(
    'learning_rate',
    -1,
    'Learning rate. -1: select by hyperopt tuning.')
flags.DEFINE_string(
    'log_dir',
    None,
    'Log directory for restoring the model from a checkpoint.')
flags.DEFINE_string('loss_func', None, 'MSE/MAPE/RMSE_MAPE: loss function.')
flags.DEFINE_float('min_learning_rate', -1, 'Minimum learning rate')
flags.DEFINE_integer(
    'nb_weeks',
    17,
    'How many week\'s data should be used for train/test.')
flags.DEFINE_integer(
    'patience',
    -1,
    'Maximum number of epochs allowed for non-improving validation error before early stopping.')
flags.DEFINE_integer('seq_len', -1, 'Sequence length.')
flags.DEFINE_integer(
    'test_every_n_epochs',
    -1,
    'Run model on the testing dataset every n epochs.')
flags.DEFINE_string(
    'traffic_df_filename',
    'data/df_highway_2012_4mon_sample.h5',
    'Path to hdf5 pandas.DataFrame.')
flags.DEFINE_bool('use_cpu_only', False, 'Set to true to only use cpu.')
flags.DEFINE_bool(
    'use_curriculum_learning',
    None,
    'Set to true to use Curriculum learning in decoding stage.')
flags.DEFINE_integer('verbose', -1, '1: to log individual sensor information.')
def main():
    """Train a DCRNN model from a JSON configuration plus command-line overrides.

    Steps, in order:
      1. Load the supervisor configuration from ``FLAGS.config_filename``.
      2. Load the sensor graph and the traffic readings from disk.
      3. Overwrite configuration entries with any flags that were supplied.
      4. Build a ``DCRNNSupervisor`` and call its ``train`` method inside a
         TensorFlow session.

    Raises:
        ValueError: If ``--config_filename`` was not supplied.
    """
    # Fail early with a clear message instead of letting open(None) raise an
    # opaque TypeError.
    if not FLAGS.config_filename:
        raise ValueError('--config_filename is required: path to the JSON supervisor configuration.')
    with open(FLAGS.config_filename) as f:
        supervisor_config = json.load(f)
    logger = log_helper.get_logger(supervisor_config.get('base_dir'), 'info.log')

    # Reads graph data.
    logger.info('Loading graph from: ' + FLAGS.graph_pkl_filename)
    sensor_ids, _, adj_mx = load_graph_data(FLAGS.graph_pkl_filename)
    # Zero out adjacency entries below 0.1.
    adj_mx[adj_mx < 0.1] = 0

    # Reads traffic data and keeps only the columns for the graph's sensors.
    logger.info('Loading traffic data from: ' + FLAGS.traffic_df_filename)
    traffic_reading_df = pd.read_hdf(FLAGS.traffic_df_filename)
    # `.loc` is the label-based replacement for the `.ix` indexer, which was
    # deprecated and later removed from pandas.
    # NOTE(review): assumes sensor_ids are column labels of the DataFrame - confirm.
    traffic_reading_df = traffic_reading_df.loc[:, sensor_ids]

    # Flags that always / conditionally replace the values from the JSON file.
    supervisor_config['use_cpu_only'] = FLAGS.use_cpu_only
    if FLAGS.log_dir:
        supervisor_config['log_dir'] = FLAGS.log_dir
    if FLAGS.use_curriculum_learning is not None:
        supervisor_config['use_curriculum_learning'] = FLAGS.use_curriculum_learning
    if FLAGS.loss_func:
        supervisor_config['loss_func'] = FLAGS.loss_func
    if FLAGS.filter_type:
        supervisor_config['filter_type'] = FLAGS.filter_type

    # Overwrites space with specified parameters. These flags default to a
    # negative value, so a non-negative value means the flag was supplied.
    numeric_overrides = ('batch_size', 'cl_decay_steps', 'epochs', 'horizon', 'learning_rate', 'l1_decay',
                         'lr_decay', 'lr_decay_epoch', 'lr_decay_interval', 'min_learning_rate',
                         'patience', 'seq_len', 'test_every_n_epochs', 'verbose')
    for name in numeric_overrides:
        value = getattr(FLAGS, name)
        if value >= 0:
            supervisor_config[name] = value

    # Hide GPUs from TensorFlow when running CPU-only.
    if FLAGS.use_cpu_only:
        tf_config = tf.ConfigProto(device_count={'GPU': 0})
    else:
        tf_config = tf.ConfigProto()
    # Per the TensorFlow GPUOptions docs, this grows GPU memory on demand
    # instead of reserving it all up front.
    tf_config.gpu_options.allow_growth = True

    with tf.Session(config=tf_config) as sess:
        supervisor = DCRNNSupervisor(traffic_reading_df=traffic_reading_df, adj_mx=adj_mx,
                                     config=supervisor_config)
        supervisor.train(sess=sess)
# Run training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()