diff --git a/run_demo_pytorch.py b/run_demo_pytorch.py index 7e27667..693dfad 100644 --- a/run_demo_pytorch.py +++ b/run_demo_pytorch.py @@ -19,7 +19,7 @@ def run_dcrnn(args): # tf_config = tf.ConfigProto(device_count={'GPU': 0}) # with tf.Session(config=tf_config) as sess: supervisor = DCRNNSupervisor(adj_mx=adj_mx, **supervisor_config) - mean_score, outputs = supervisor.evaluate() + mean_score, outputs = supervisor.evaluate('test') np.savez_compressed(args.output_filename, **outputs) print('Predictions saved as {}.'.format(args.output_filename))