From 3e94a0ff0e03de40007d5daecf70c0c7b8fa195f Mon Sep 17 00:00:00 2001 From: Yaguang Date: Wed, 18 Apr 2018 11:50:39 -0700 Subject: [PATCH] Adds support for python 3.5 and python 3.6. --- README.md | 6 +++--- gen_adj_mx.py | 4 ++-- lib/dcrnn_utils.py | 4 ++-- lib/utils.py | 13 +++++++++++++ 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 045a812..ad2c350 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu, [Diffusion Convolutional Recurrent - numpy>=1.12.1 - pandas>=0.19.2 - tensorflow>=1.3.0 -- python 2.7 + Dependency can be installed using the following command: ```bash @@ -38,7 +38,7 @@ python gen_adj_mx.py --sensor_ids_filename=data/sensor_graph/graph_sensor_ids.t ## Train the Model ```bash -python dcrnn_train.py --config_filename=data/model/dcrnn_config.json +python dcrnn_train.py --config_filename=data/model/dcrnn_config.yaml ``` @@ -59,7 +59,7 @@ If you find this repository useful in your research, please cite the following p @inproceedings{li2018dcrnn_traffic, title={Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting}, author={Li, Yaguang and Yu, Rose and Shahabi, Cyrus and Liu, Yan}, - booktitle={International Conference on Learning Representations}, + booktitle={International Conference on Learning Representations (ICLR '18)}, year={2018} } ``` diff --git a/gen_adj_mx.py b/gen_adj_mx.py index 7cedca6..4e212e8 100644 --- a/gen_adj_mx.py +++ b/gen_adj_mx.py @@ -59,5 +59,5 @@ if __name__ == '__main__': distance_df = pd.read_csv(FLAGS.distances_filename, dtype={'from': 'str', 'to': 'str'}) _, sensor_id_to_ind, adj_mx = get_adjacency_matrix(distance_df, sensor_ids) # Save to pickle file. - with open(FLAGS.output_pkl_filename, 'w') as f: - pickle.dump([sensor_ids, sensor_id_to_ind, adj_mx], f) + with open(FLAGS.output_pkl_filename, 'wb') as f: + pickle.dump([sensor_ids, sensor_id_to_ind, adj_mx], f, protocol=2) diff --git a/lib/dcrnn_utils.py b/lib/dcrnn_utils.py index f0ee54d..ee4a34e 100644 --- a/lib/dcrnn_utils.py +++ b/lib/dcrnn_utils.py @@ -4,11 +4,11 @@ import scipy.sparse as sp from scipy.sparse import linalg from lib.tf_utils import sparse_matrix_to_tf_sparse_tensor +from lib.utils import load_pickle def load_graph_data(pkl_filename): - with open(pkl_filename) as f: - sensor_ids, sensor_id_to_ind, adj_mx = pickle.load(f) + sensor_ids, sensor_id_to_ind, adj_mx = load_pickle(pkl_filename) return sensor_ids, sensor_id_to_ind, adj_mx diff --git a/lib/utils.py b/lib/utils.py index b4cd3ff..38c12fa 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -1,6 +1,7 @@ import datetime import numpy as np import pandas as pd +import pickle class StandardScaler: @@ -155,6 +156,18 @@ def generate_graph_seq2seq_io_data_with_time(df, batch_size, seq_len, horizon, n return x, y +def load_pickle(pickle_file): + try: + with open(pickle_file, 'rb') as f: + pickle_data = pickle.load(f) + except UnicodeDecodeError as e: + with open(pickle_file, 'rb') as f: + pickle_data = pickle.load(f, encoding='latin1') + except Exception as e: + print('Unable to load data ', pickle_file, ':', e) + raise + return pickle_data + def round_down(num, divisor): return num - (num % divisor)