88 lines
4.4 KiB
Python
88 lines
4.4 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import json
|
|
|
|
import pandas as pd
|
|
import tensorflow as tf
|
|
|
|
from lib import log_helper
|
|
from lib.dcrnn_utils import load_graph_data
|
|
from model.dcrnn_supervisor import DCRNNSupervisor
|
|
|
|
# Command-line flags.
#
# Most numeric flags default to -1 and most string flags to None. main() treats those
# defaults as "not set": numeric flags are copied into the supervisor config only when
# their value is >= 0, and string flags only when they are non-empty, so the JSON config
# file supplies the effective defaults.
flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_integer('batch_size', -1, 'Batch size')
flags.DEFINE_integer('cl_decay_steps', -1,
                     'Parameter to control the decay speed of probability of feeding ground truth '
                     'instead of model output.')
flags.DEFINE_string('config_filename', None, 'Configuration filename for restoring the model.')
flags.DEFINE_integer('epochs', -1, 'Maximum number of epochs to train.')
flags.DEFINE_string('filter_type', None, 'laplacian/random_walk/dual_random_walk.')
flags.DEFINE_string('graph_pkl_filename', 'data/sensor_graph/adj_mx.pkl',
                    'Pickle file containing: sensor_ids, sensor_id_to_ind_map, dist_matrix')
flags.DEFINE_integer('horizon', -1, 'Maximum number of timestamps to predict.')
flags.DEFINE_float('l1_decay', -1.0, 'L1 Regularization')
flags.DEFINE_float('lr_decay', -1.0, 'Learning rate decay.')
flags.DEFINE_integer('lr_decay_epoch', -1, 'The epoch at which to start decaying the parameter.')
flags.DEFINE_integer('lr_decay_interval', -1, 'Interval between each decay.')
# NOTE(review): main() treats -1 as "keep the config file's value"; the hyperopt remark in
# this help text is not reflected anywhere in this file -- confirm it is still accurate.
flags.DEFINE_float('learning_rate', -1, 'Learning rate. -1: select by hyperopt tuning.')
flags.DEFINE_string('log_dir', None, 'Log directory for restoring the model from a checkpoint.')
flags.DEFINE_string('loss_func', None, 'MSE/MAPE/RMSE_MAPE: loss function.')
flags.DEFINE_float('min_learning_rate', -1, 'Minimum learning rate')
# NOTE(review): nb_weeks is not read by main() in this file -- confirm whether it is still used.
flags.DEFINE_integer('nb_weeks', 17, 'How many weeks of data should be used for train/test.')
flags.DEFINE_integer('patience', -1,
                     'Maximum number of epochs allowed for non-improving validation error before '
                     'early stopping.')
flags.DEFINE_integer('seq_len', -1, 'Sequence length.')
flags.DEFINE_integer('test_every_n_epochs', -1, 'Run model on the testing dataset every n epochs.')
flags.DEFINE_string('traffic_df_filename', 'data/df_highway_2012_4mon_sample.h5',
                    'Path to hdf5 pandas.DataFrame.')
flags.DEFINE_bool('use_cpu_only', False, 'Set to true to only use cpu.')
# Default None (rather than False) lets main() distinguish "not given" from an explicit false.
flags.DEFINE_bool('use_curriculum_learning', None,
                  'Set to true to use Curriculum learning in decoding stage.')
flags.DEFINE_integer('verbose', -1, '1: to log individual sensor information.')
|
|
|
|
|
|
def _override_config_from_flags(supervisor_config):
    """Copy explicitly-set command-line flag values into ``supervisor_config`` in place.

    ``use_cpu_only`` is always copied. String flags are copied only when non-empty,
    ``use_curriculum_learning`` only when it is not None, and the numeric flags listed
    below only when their value is >= 0 (their -1 default means "keep the config value").

    Args:
        supervisor_config: dict loaded from the JSON config file; mutated in place.
    """
    supervisor_config['use_cpu_only'] = FLAGS.use_cpu_only
    if FLAGS.log_dir:
        supervisor_config['log_dir'] = FLAGS.log_dir
    if FLAGS.use_curriculum_learning is not None:
        supervisor_config['use_curriculum_learning'] = FLAGS.use_curriculum_learning
    if FLAGS.loss_func:
        supervisor_config['loss_func'] = FLAGS.loss_func
    if FLAGS.filter_type:
        supervisor_config['filter_type'] = FLAGS.filter_type

    # Numeric hyper-parameters; each name appears once (the original list repeated
    # 'learning_rate').
    for name in ('batch_size', 'cl_decay_steps', 'epochs', 'horizon', 'learning_rate', 'l1_decay',
                 'lr_decay', 'lr_decay_epoch', 'lr_decay_interval', 'min_learning_rate',
                 'patience', 'seq_len', 'test_every_n_epochs', 'verbose'):
        value = getattr(FLAGS, name)
        if value >= 0:
            supervisor_config[name] = value


def main():
    """Load the config, graph and traffic data, then train a DCRNN model.

    Reads the JSON supervisor config named by ``--config_filename``, overrides its
    entries with any explicitly-set command-line flags, and runs
    ``DCRNNSupervisor.train`` inside a TensorFlow session.

    Raises:
        ValueError: if ``--config_filename`` is not given.
    """
    if not FLAGS.config_filename:
        raise ValueError('--config_filename is required.')

    # Reads the supervisor configuration.
    with open(FLAGS.config_filename) as f:
        supervisor_config = json.load(f)
    logger = log_helper.get_logger(supervisor_config.get('base_dir'), 'info.log')

    # Reads graph data and zeroes every adjacency entry below 0.1.
    # NOTE(review): assumes adj_mx is a numpy array supporting boolean-mask assignment.
    logger.info('Loading graph from: ' + FLAGS.graph_pkl_filename)
    sensor_ids, _, adj_mx = load_graph_data(FLAGS.graph_pkl_filename)
    adj_mx[adj_mx < 0.1] = 0

    # Reads traffic data and keeps only the graph's sensors, in graph order.
    # ``.loc`` replaces the deprecated ``.ix`` indexer, which pandas 1.0 removed.
    logger.info('Loading traffic data from: ' + FLAGS.traffic_df_filename)
    traffic_reading_df = pd.read_hdf(FLAGS.traffic_df_filename)
    traffic_reading_df = traffic_reading_df.loc[:, sensor_ids]

    _override_config_from_flags(supervisor_config)

    if FLAGS.use_cpu_only:
        tf_config = tf.ConfigProto(device_count={'GPU': 0})
    else:
        tf_config = tf.ConfigProto()
    # Allocate GPU memory on demand rather than reserving it all up front.
    tf_config.gpu_options.allow_growth = True

    with tf.Session(config=tf_config) as sess:
        supervisor = DCRNNSupervisor(traffic_reading_df=traffic_reading_df, adj_mx=adj_mx,
                                     config=supervisor_config)
        supervisor.train(sess=sess)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the training pipeline only when executed directly.
    main()
|