Adds support for python 3.5 and python 3.6.

This commit is contained in:
Yaguang 2018-04-18 11:50:39 -07:00
parent fcdf62d6de
commit 3e94a0ff0e
4 changed files with 20 additions and 7 deletions

View File

@ -14,7 +14,7 @@ Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu, [Diffusion Convolutional Recurrent
- numpy>=1.12.1
- pandas>=0.19.2
- tensorflow>=1.3.0
- python 2.7
Dependency can be installed using the following command:
```bash
@ -38,7 +38,7 @@ python gen_adj_mx.py --sensor_ids_filename=data/sensor_graph/graph_sensor_ids.t
## Train the Model
```bash
python dcrnn_train.py --config_filename=data/model/dcrnn_config.json
python dcrnn_train.py --config_filename=data/model/dcrnn_config.yaml
```
@ -59,7 +59,7 @@ If you find this repository useful in your research, please cite the following p
@inproceedings{li2018dcrnn_traffic,
title={Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting},
author={Li, Yaguang and Yu, Rose and Shahabi, Cyrus and Liu, Yan},
booktitle={International Conference on Learning Representations},
booktitle={International Conference on Learning Representations (ICLR '18)},
year={2018}
}
```

View File

@ -59,5 +59,5 @@ if __name__ == '__main__':
distance_df = pd.read_csv(FLAGS.distances_filename, dtype={'from': 'str', 'to': 'str'})
_, sensor_id_to_ind, adj_mx = get_adjacency_matrix(distance_df, sensor_ids)
# Save to pickle file.
with open(FLAGS.output_pkl_filename, 'w') as f:
pickle.dump([sensor_ids, sensor_id_to_ind, adj_mx], f)
with open(FLAGS.output_pkl_filename, 'wb') as f:
pickle.dump([sensor_ids, sensor_id_to_ind, adj_mx], f, protocol=2)

View File

@ -4,11 +4,11 @@ import scipy.sparse as sp
from scipy.sparse import linalg
from lib.tf_utils import sparse_matrix_to_tf_sparse_tensor
from lib.utils import load_pickle
def load_graph_data(pkl_filename):
with open(pkl_filename) as f:
sensor_ids, sensor_id_to_ind, adj_mx = pickle.load(f)
sensor_ids, sensor_id_to_ind, adj_mx = load_pickle(pkl_filename)
return sensor_ids, sensor_id_to_ind, adj_mx

View File

@ -1,6 +1,7 @@
import datetime
import numpy as np
import pandas as pd
import pickle
class StandardScaler:
@ -155,6 +156,18 @@ def generate_graph_seq2seq_io_data_with_time(df, batch_size, seq_len, horizon, n
return x, y
def load_pickle(pickle_file):
try:
with open(pickle_file, 'rb') as f:
pickle_data = pickle.load(f)
except UnicodeDecodeError as e:
with open(pickle_file, 'rb') as f:
pickle_data = pickle.load(f, encoding='latin1')
except Exception as e:
print('Unable to load data ', pickle_file, ':', e)
raise
return pickle_data
def round_down(num, divisor):
return num - (num % divisor)