Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting

[Figure: Diffusion Convolutional Recurrent Neural Network]

This is a TensorFlow implementation of the Diffusion Convolutional Recurrent Neural Network (DCRNN) described in the following paper:
Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu, Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting, ICLR 2018.

Requirements

  • scipy>=0.19.0
  • numpy>=1.12.1
  • pandas>=0.19.2
  • tensorflow>=1.3.0
  • pyaml

Dependencies can be installed using the following command:

pip install -r requirements.txt

Data Preparation

The traffic data file for Los Angeles, df_highway_2012_4mon_sample.h5, is available here and should be placed in the data/METR-LA folder. The sensor locations are available at data/sensor_graph/graph_sensor_locations.csv. Then run the following command to generate the train/val/test datasets:

python -m scripts.generate_training_data --output_dir=data/METR-LA

The generated train/val/test datasets will be saved to data/METR-LA/{train,val,test}.npz.
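The README does not describe the script's internals. As a rough illustration of how a train/val/test split for sequence-to-sequence traffic forecasting is commonly built, the sketch below cuts a (timesteps, sensors) array into 12-step input windows paired with the following 12-step target windows, then splits the samples chronologically. The window lengths, the 70/10/20 ratio and the `x`/`y` array names are assumptions, and `make_seq2seq_samples`, `split_train_val_test` and `save_splits` are hypothetical helpers, not functions from `scripts/generate_training_data.py`.

```python
import os

import numpy as np


def make_seq2seq_samples(series, input_len=12, horizon=12):
    """Hypothetical helper: cut a (num_timesteps, num_sensors) array into windows.

    Returns x with shape (num_samples, input_len, num_sensors), the past readings,
    and y with shape (num_samples, horizon, num_sensors), the readings that follow.
    """
    series = np.asarray(series)
    x_offsets = np.arange(-input_len + 1, 1)  # e.g. -11 .. 0, relative to time t
    y_offsets = np.arange(1, horizon + 1)     # e.g. 1 .. 12, relative to time t
    first = input_len - 1                     # first t with a full input history
    last = series.shape[0] - horizon          # exclusive end so targets stay in range
    if last <= first:
        raise ValueError("series is too short for the requested window lengths")
    x = np.stack([series[t + x_offsets] for t in range(first, last)])
    y = np.stack([series[t + y_offsets] for t in range(first, last)])
    return x, y


def split_train_val_test(x, y, train_frac=0.7, test_frac=0.2):
    """Chronological split (no shuffling); the 70/10/20 ratio is an assumption."""
    n = x.shape[0]
    n_train = int(n * train_frac)
    n_test = int(n * test_frac)
    n_val = n - n_train - n_test
    return {
        "train": (x[:n_train], y[:n_train]),
        "val": (x[n_train:n_train + n_val], y[n_train:n_train + n_val]),
        "test": (x[n_train + n_val:], y[n_train + n_val:]),
    }


def save_splits(splits, output_dir):
    """Write one {name}.npz file per split; the x/y key names are assumptions."""
    os.makedirs(output_dir, exist_ok=True)
    for name, (x, y) in splits.items():
        np.savez_compressed(os.path.join(output_dir, name + ".npz"), x=x, y=y)
```

For example, `save_splits(split_train_val_test(*make_seq2seq_samples(series)), "data/METR-LA")` would write three files in the layout above. Splitting chronologically rather than randomly keeps future readings out of the training set.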

Run the Pre-trained Model

python run_demo.py

The predictions generated by DCRNN are saved in data/results/dcrnn_predictions_[1-12].h5.
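The README does not state how the predictions are scored. Assuming each prediction file can be loaded into a NumPy array alongside the matching ground truth (for example with `pandas.read_hdf`, which requires PyTables, if the files are pandas HDF5 stores), traffic forecasts are commonly evaluated with MAE, RMSE and MAPE computed only over entries that hold a valid reading. The sketch below treats a label of 0 as a missing reading; that convention and the `masked_*` function names are assumptions.

```python
import numpy as np


def _valid_pairs(preds, labels, null_val):
    """Keep only prediction/label entries whose label is not the missing-value marker."""
    preds = np.asarray(preds, dtype=float)
    labels = np.asarray(labels, dtype=float)
    mask = labels != null_val  # assumption: null_val (0 by default) marks a missing reading
    return preds[mask], labels[mask]


def masked_mae(preds, labels, null_val=0.0):
    """Mean absolute error over valid entries; NaN if no entry is valid."""
    y_hat, y = _valid_pairs(preds, labels, null_val)
    return float(np.mean(np.abs(y_hat - y))) if y.size else float("nan")


def masked_rmse(preds, labels, null_val=0.0):
    """Root mean squared error over valid entries; NaN if no entry is valid."""
    y_hat, y = _valid_pairs(preds, labels, null_val)
    return float(np.sqrt(np.mean(np.square(y_hat - y)))) if y.size else float("nan")


def masked_mape(preds, labels, null_val=0.0):
    """Mean absolute percentage error as a fraction (multiply by 100 for percent).

    With the default null_val=0, zero labels are excluded, so the division is safe.
    """
    y_hat, y = _valid_pairs(preds, labels, null_val)
    return float(np.mean(np.abs(y_hat - y) / np.abs(y))) if y.size else float("nan")
```

Masking matters because a sensor outage recorded as 0 would otherwise count as a real reading and inflate the error.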

Model Training

python dcrnn_train.py --config_filename=data/model/dcrnn_config.yaml

Each epoch takes about 5 minutes on a single GTX 1080 Ti.

Graph Construction

As the current implementation is based on pre-calculated road network distances between sensors, it only supports sensor IDs in Los Angeles (see data/sensor_graph/sensor_info_201206.csv).

python -m scripts.gen_adj_mx --sensor_ids_filename=data/sensor_graph/graph_sensor_ids.txt --normalized_k=0.1 \
    --output_pkl_filename=data/sensor_graph/adj_mx.pkl
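The DCRNN paper builds the sensor graph with a thresholded Gaussian kernel over road-network distances: each weight is exp(-(d / sigma)^2), where d is the distance between two sensors and sigma is the standard deviation of the distances. The sketch below illustrates that idea, assuming `--normalized_k` is the threshold below which kernel weights are dropped and that pairs with no known distance are stored as `np.inf`. `gaussian_kernel_adjacency` is a hypothetical name, not the function in `scripts/gen_adj_mx.py`.

```python
import numpy as np


def gaussian_kernel_adjacency(dist, normalized_k=0.1):
    """Hypothetical helper: weighted adjacency matrix from pairwise road distances.

    dist: (num_sensors, num_sensors) array; np.inf marks pairs with no known distance
    (an assumption for this sketch). Weights below normalized_k are set to 0.
    """
    dist = np.asarray(dist, dtype=float)
    finite = dist[np.isfinite(dist)]
    sigma = finite.std()  # spread of the known distances sets the kernel width
    if sigma == 0:
        raise ValueError("distances have zero spread; kernel width is undefined")
    adj = np.exp(-np.square(dist / sigma))  # exp(-inf) == 0 for unreachable pairs
    adj[adj < normalized_k] = 0.0           # drop weak edges to keep the graph sparse
    return adj
```

Because the weights decay quickly with distance, a larger `normalized_k` yields a sparser graph. The resulting matrix could be written with `pickle.dump` to a path such as data/sensor_graph/adj_mx.pkl, although the exact pickle layout the training code expects is not specified here.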

More details are being added ...

Citation

If you find this repository useful in your research, please cite the following paper:

@inproceedings{li2018dcrnn_traffic,
  title={Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting},
  author={Li, Yaguang and Yu, Rose and Shahabi, Cyrus and Liu, Yan},
  booktitle={International Conference on Learning Representations (ICLR '18)},
  year={2018}
}