From f92e7295a06c654603e5af2353fc1e0f1cd48e43 Mon Sep 17 00:00:00 2001 From: Chintan Shah Date: Tue, 8 Oct 2019 13:05:49 -0400 Subject: [PATCH] added run_demo_pytorch --- data/model/dcrnn_test_config.yaml | 4 ++-- run_demo_pytorch.py | 35 +++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) create mode 100644 run_demo_pytorch.py diff --git a/data/model/dcrnn_test_config.yaml b/data/model/dcrnn_test_config.yaml index fce08bc..93cbe59 100644 --- a/data/model/dcrnn_test_config.yaml +++ b/data/model/dcrnn_test_config.yaml @@ -18,14 +18,14 @@ model: num_nodes: 207 num_rnn_layers: 2 output_dim: 1 - rnn_units: 16 + rnn_units: 64 seq_len: 12 use_curriculum_learning: true train: base_lr: 0.01 dropout: 0 - epoch: 0 + epoch: 44 epochs: 100 epsilon: 1.0e-3 global_step: 0 diff --git a/run_demo_pytorch.py b/run_demo_pytorch.py new file mode 100644 index 0000000..7e27667 --- /dev/null +++ b/run_demo_pytorch.py @@ -0,0 +1,35 @@ +import argparse +import numpy as np +import os +import sys +import yaml + +from lib.utils import load_graph_data +from model.pytorch.dcrnn_supervisor import DCRNNSupervisor + + +def run_dcrnn(args): + with open(args.config_filename) as f: + supervisor_config = yaml.load(f) + + graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename') + sensor_ids, sensor_id_to_ind, adj_mx = load_graph_data(graph_pkl_filename) + + # if args.use_cpu_only: + # tf_config = tf.ConfigProto(device_count={'GPU': 0}) + # with tf.Session(config=tf_config) as sess: + supervisor = DCRNNSupervisor(adj_mx=adj_mx, **supervisor_config) + mean_score, outputs = supervisor.evaluate() + np.savez_compressed(args.output_filename, **outputs) + print('Predictions saved as {}.'.format(args.output_filename)) + + +if __name__ == '__main__': + sys.path.append(os.getcwd()) + parser = argparse.ArgumentParser() + parser.add_argument('--use_cpu_only', default=False, type=str, help='Whether to run tensorflow on cpu.') + parser.add_argument('--config_filename', default='data/model/pretrained/METR-LA/config.yaml', type=str, + help='Config file for pretrained model.') + parser.add_argument('--output_filename', default='data/dcrnn_predictions.npz') + args = parser.parse_args() + run_dcrnn(args)