Adds support for python 3.5 and python 3.6.
This commit is contained in:
parent
fcdf62d6de
commit
3e94a0ff0e
|
|
@ -14,7 +14,7 @@ Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu, [Diffusion Convolutional Recurrent
|
|||
- numpy>=1.12.1
|
||||
- pandas>=0.19.2
|
||||
- tensorflow>=1.3.0
|
||||
- python 2.7
|
||||
|
||||
|
||||
Dependency can be installed using the following command:
|
||||
```bash
|
||||
|
|
@ -38,7 +38,7 @@ python gen_adj_mx.py --sensor_ids_filename=data/sensor_graph/graph_sensor_ids.t
|
|||
|
||||
## Train the Model
|
||||
```bash
|
||||
python dcrnn_train.py --config_filename=data/model/dcrnn_config.json
|
||||
python dcrnn_train.py --config_filename=data/model/dcrnn_config.yaml
|
||||
```
|
||||
|
||||
|
||||
|
|
@ -59,7 +59,7 @@ If you find this repository useful in your research, please cite the following p
|
|||
@inproceedings{li2018dcrnn_traffic,
|
||||
title={Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting},
|
||||
author={Li, Yaguang and Yu, Rose and Shahabi, Cyrus and Liu, Yan},
|
||||
booktitle={International Conference on Learning Representations},
|
||||
booktitle={International Conference on Learning Representations (ICLR '18)},
|
||||
year={2018}
|
||||
}
|
||||
```
|
||||
|
|
|
|||
|
|
@ -59,5 +59,5 @@ if __name__ == '__main__':
|
|||
distance_df = pd.read_csv(FLAGS.distances_filename, dtype={'from': 'str', 'to': 'str'})
|
||||
_, sensor_id_to_ind, adj_mx = get_adjacency_matrix(distance_df, sensor_ids)
|
||||
# Save to pickle file.
|
||||
with open(FLAGS.output_pkl_filename, 'w') as f:
|
||||
pickle.dump([sensor_ids, sensor_id_to_ind, adj_mx], f)
|
||||
with open(FLAGS.output_pkl_filename, 'wb') as f:
|
||||
pickle.dump([sensor_ids, sensor_id_to_ind, adj_mx], f, protocol=2)
|
||||
|
|
|
|||
|
|
@ -4,11 +4,11 @@ import scipy.sparse as sp
|
|||
|
||||
from scipy.sparse import linalg
|
||||
from lib.tf_utils import sparse_matrix_to_tf_sparse_tensor
|
||||
from lib.utils import load_pickle
|
||||
|
||||
|
||||
def load_graph_data(pkl_filename):
|
||||
with open(pkl_filename) as f:
|
||||
sensor_ids, sensor_id_to_ind, adj_mx = pickle.load(f)
|
||||
sensor_ids, sensor_id_to_ind, adj_mx = load_pickle(pkl_filename)
|
||||
return sensor_ids, sensor_id_to_ind, adj_mx
|
||||
|
||||
|
||||
|
|
|
|||
13
lib/utils.py
13
lib/utils.py
|
|
@ -1,6 +1,7 @@
|
|||
import datetime
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pickle
|
||||
|
||||
|
||||
class StandardScaler:
|
||||
|
|
@ -155,6 +156,18 @@ def generate_graph_seq2seq_io_data_with_time(df, batch_size, seq_len, horizon, n
|
|||
return x, y
|
||||
|
||||
|
||||
def load_pickle(pickle_file):
|
||||
try:
|
||||
with open(pickle_file, 'rb') as f:
|
||||
pickle_data = pickle.load(f)
|
||||
except UnicodeDecodeError as e:
|
||||
with open(pickle_file, 'rb') as f:
|
||||
pickle_data = pickle.load(f, encoding='latin1')
|
||||
except Exception as e:
|
||||
print('Unable to load data ', pickle_file, ':', e)
|
||||
raise
|
||||
return pickle_data
|
||||
|
||||
def round_down(num, divisor):
|
||||
return num - (num % divisor)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue