diff --git a/data/model/dcrnn_test_config.yaml b/data/model/dcrnn_test_config.yaml index fce08bc..93cbe59 100644 --- a/data/model/dcrnn_test_config.yaml +++ b/data/model/dcrnn_test_config.yaml @@ -18,14 +18,14 @@ model: num_nodes: 207 num_rnn_layers: 2 output_dim: 1 - rnn_units: 16 + rnn_units: 64 seq_len: 12 use_curriculum_learning: true train: base_lr: 0.01 dropout: 0 - epoch: 0 + epoch: 44 epochs: 100 epsilon: 1.0e-3 global_step: 0 diff --git a/run_demo_pytorch.py b/run_demo_pytorch.py new file mode 100644 index 0000000..7e27667 --- /dev/null +++ b/run_demo_pytorch.py @@ -0,0 +1,35 @@ +import argparse +import numpy as np +import os +import sys +import yaml + +from lib.utils import load_graph_data +from model.pytorch.dcrnn_supervisor import DCRNNSupervisor + + +def run_dcrnn(args): + with open(args.config_filename) as f: + supervisor_config = yaml.load(f) + + graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename') + sensor_ids, sensor_id_to_ind, adj_mx = load_graph_data(graph_pkl_filename) + + # if args.use_cpu_only: + # tf_config = tf.ConfigProto(device_count={'GPU': 0}) + # with tf.Session(config=tf_config) as sess: + supervisor = DCRNNSupervisor(adj_mx=adj_mx, **supervisor_config) + mean_score, outputs = supervisor.evaluate() + np.savez_compressed(args.output_filename, **outputs) + print('Predictions saved as {}.'.format(args.output_filename)) + + +if __name__ == '__main__': + sys.path.append(os.getcwd()) + parser = argparse.ArgumentParser() + parser.add_argument('--use_cpu_only', default=False, type=str, help='Whether to run tensorflow on cpu.') + parser.add_argument('--config_filename', default='data/model/pretrained/METR-LA/config.yaml', type=str, + help='Config file for pretrained model.') + parser.add_argument('--output_filename', default='data/dcrnn_predictions.npz') + args = parser.parse_args() + run_dcrnn(args)