
Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting

This is a TensorFlow implementation of the Diffusion Convolutional Recurrent Neural Network (DCRNN) described in the following paper:
Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu, Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting, ICLR 2018.

Requirements

  • scipy>=0.19.0
  • numpy>=1.12.1
  • pandas>=0.19.2
  • pyaml
  • statsmodels
  • tensorflow>=1.3.0

Dependencies can be installed with the following command:

pip install -r requirements.txt

Data Preparation

The traffic data files for Los Angeles (METR-LA) and the Bay Area (PEMS-BAY), i.e., metr-la.h5 and pems-bay.h5, are available at Google Drive or Baidu Yun and should be placed in the data/ folder. The *.h5 files store the data as a pandas.DataFrame in the HDF5 file format. Here is an example:

                     sensor_0  sensor_1  sensor_2  sensor_n
2018/01/01 00:00:00      60.0      65.0      70.0       ...
2018/01/01 00:05:00      61.0      64.0      65.0       ...
2018/01/01 00:10:00      63.0      65.0      60.0       ...
...                       ...       ...       ...       ...

Here is an article about Using HDF5 with Python.
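The example above shows the data layout. Rows are indexed by timestamp, there is one column per sensor, and in this excerpt the rows are 5 minutes apart. With pandas, such a file would normally be opened with pandas.read_hdf, which needs the optional PyTables package.

To keep the illustration dependency-free, the following minimal sketch uses only the standard library. It rebuilds the three example rows by hand, checks the 5-minute spacing, and extracts one sensor's column. The helper names (parse_rows, has_regular_interval, sensor_column) are hypothetical and are not part of this repository.

```python
from datetime import datetime, timedelta

# Hypothetical in-memory copy of the example table: (timestamp, readings)
# rows, with readings listed in SENSORS order.
SENSORS = ["sensor_0", "sensor_1", "sensor_2"]
ROWS = [
    ("2018/01/01 00:00:00", [60.0, 65.0, 70.0]),
    ("2018/01/01 00:05:00", [61.0, 64.0, 65.0]),
    ("2018/01/01 00:10:00", [63.0, 65.0, 60.0]),
]


def parse_rows(rows):
    """Turn (timestamp string, readings) rows into (datetime, readings) rows."""
    return [(datetime.strptime(ts, "%Y/%m/%d %H:%M:%S"), vals) for ts, vals in rows]


def has_regular_interval(timestamps, step=timedelta(minutes=5)):
    """True if every pair of consecutive timestamps is exactly `step` apart."""
    return all(b - a == step for a, b in zip(timestamps, timestamps[1:]))


def sensor_column(parsed, sensors, name):
    """Return one sensor's readings, i.e. a single column of the table."""
    col = sensors.index(name)
    return [vals[col] for _, vals in parsed]


parsed = parse_rows(ROWS)
timestamps = [ts for ts, _ in parsed]
print(has_regular_interval(timestamps))            # True
print(sensor_column(parsed, SENSORS, "sensor_1"))  # [65.0, 64.0, 65.0]
```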

Run the following commands to generate the train/val/test datasets at data/{METR-LA,PEMS-BAY}/{train,val,test}.npz.

# Create data directories
mkdir -p data/{METR-LA,PEMS-BAY}

# METR-LA
python -m scripts.generate_training_data --output_dir=data/METR-LA --traffic_df_filename=data/metr-la.h5

# PEMS-BAY
python -m scripts.generate_training_data --output_dir=data/PEMS-BAY --traffic_df_filename=data/pems-bay.h5
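The README does not describe what scripts.generate_training_data does internally. One common way to turn such a table into supervised samples is a sliding window followed by a chronological split:

- Slide a fixed window over the time series, so that each sample pairs a span of past readings (inputs) with the readings that immediately follow (targets).
- Split the resulting samples in time order into train/val/test.

The following minimal sketch shows this approach. Several details are assumptions for illustration: the window lengths (12 steps in and 12 out, i.e., one hour each at 5-minute resolution), the 70/10/20 split, and the function names.

```python
def make_windows(series, num_in=12, num_out=12):
    """Pair each run of `num_in` past steps with the `num_out` steps after it.

    series: list of per-timestep readings (e.g. one table row per step).
    num_in / num_out default to 12 (one hour at 5-minute steps), an assumption.
    """
    samples = []
    for t in range(num_in, len(series) - num_out + 1):
        x = series[t - num_in:t]   # inputs: steps t-num_in .. t-1
        y = series[t:t + num_out]  # targets: steps t .. t+num_out-1
        samples.append((x, y))
    return samples


def chronological_split(samples, train_frac=0.7, test_frac=0.2):
    """Split samples in time order (no shuffling); validation gets the rest.

    The 70/10/20 default ratio is an assumption for illustration.
    """
    n = len(samples)
    num_train = round(n * train_frac)
    num_test = round(n * test_frac)
    num_val = n - num_train - num_test
    train = samples[:num_train]
    val = samples[num_train:num_train + num_val]
    test = samples[num_train + num_val:]
    return train, val, test


series = list(range(40))  # stand-in for 40 consecutive 5-minute readings
samples = make_windows(series)
train, val, test = chronological_split(samples)
print(len(samples), len(train), len(val), len(test))  # 17 12 2 3
```

Splitting in time order keeps the later (test) period unseen during training. Consecutive windows overlap heavily, so a random split would leak near-duplicate samples across the sets.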

Graph Construction

As the current implementation relies on pre-calculated road network distances between sensors, it only supports sensor IDs in Los Angeles (see data/sensor_graph/sensor_info_201206.csv).

python -m scripts.gen_adj_mx --sensor_ids_filename=data/sensor_graph/graph_sensor_ids.txt --normalized_k=0.1 \
    --output_pkl_filename=data/sensor_graph/adj_mx.pkl

The locations of the sensors in Los Angeles (METR-LA) are also available at data/sensor_graph/graph_sensor_locations.csv.
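The ICLR paper builds the sensor graph with a thresholded Gaussian kernel on road network distances: W_ij = exp(-dist(v_i, v_j)^2 / sigma^2), where sigma is the standard deviation of the distances, and small weights are set to zero. The following minimal sketch makes two assumptions about scripts.gen_adj_mx:

- --normalized_k acts as the cutoff on the kernel weight.
- Distances arrive as (from_id, to_id, distance) triples.

The function name gaussian_adjacency is hypothetical. Road distances are directional, so the resulting matrix need not be symmetric.

```python
import math


def gaussian_adjacency(sensor_ids, distances, normalized_k=0.1):
    """Weighted adjacency matrix from directed road distances (a sketch).

    distances: iterable of (from_id, to_id, distance) triples; this input
    format is an assumption. Pairs involving unknown sensors are skipped,
    and pairs with no listed distance keep weight 0.
    Weight: exp(-(d / sigma)^2), sigma = population std of the listed
    distances; weights below normalized_k are set to 0.
    """
    index = {sid: i for i, sid in enumerate(sensor_ids)}
    known = [(index[a], index[b], d) for a, b, d in distances
             if a in index and b in index]
    values = [d for _, _, d in known]
    if not values:
        raise ValueError("no distances between known sensors")
    mean = sum(values) / len(values)
    sigma = math.sqrt(sum((d - mean) ** 2 for d in values) / len(values))
    if sigma == 0:
        raise ValueError("distances must not all be equal")
    n = len(sensor_ids)
    adj = [[0.0] * n for _ in range(n)]
    for i, j, d in known:
        w = math.exp(-((d / sigma) ** 2))
        adj[i][j] = w if w >= normalized_k else 0.0  # sparsify weak links
    return adj


sensors = ["a", "b", "c"]
dists = [("a", "a", 0.0), ("a", "b", 1.0), ("b", "c", 3.0)]
adj = gaussian_adjacency(sensors, dists)
# a->b keeps a weight; b->c falls below 0.1 and is dropped; b->a stays 0
```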

Run the Pre-trained Models on METR-LA and PEMS-BAY

# METR-LA
python run_demo.py --config_filename=data/model/pretrained/METR-LA/config.yaml

# PEMS-BAY
python run_demo.py --config_filename=data/model/pretrained/PEMS-BAY/config.yaml

The generated DCRNN predictions are written to data/results/dcrnn_predictions.
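Predictions on these datasets are commonly compared against ground truth using MAE, RMSE, and MAPE. The README does not describe the format of the saved prediction files, so loading them is left out here.

The following minimal sketch computes the three metrics. It skips entries whose ground truth equals 0, on the assumption that 0 marks a missing sensor reading. That convention and the name masked_metrics are illustrative assumptions, not the repository's own metric code.

```python
import math


def masked_metrics(preds, labels, null_val=0.0):
    """Return (MAE, RMSE, MAPE %) over entries whose label is not null_val.

    Treating null_val (0 by default) as "missing reading" is an assumption.
    preds, labels: flat sequences of equal length.
    """
    if len(preds) != len(labels):
        raise ValueError("preds and labels must have the same length")
    pairs = [(p, y) for p, y in zip(preds, labels) if y != null_val]
    if not pairs:
        raise ValueError("no non-missing labels to evaluate")
    n = len(pairs)
    mae = sum(abs(p - y) for p, y in pairs) / n
    rmse = math.sqrt(sum((p - y) ** 2 for p, y in pairs) / n)
    mape = 100.0 * sum(abs(p - y) / abs(y) for p, y in pairs) / n
    return mae, rmse, mape


# The middle label is 0 (treated as missing), so only two pairs are scored.
mae, rmse, mape = masked_metrics([62.0, 60.0, 55.0], [60.0, 0.0, 50.0])
print(mae)  # 3.5
```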

Model Training

# METR-LA
python dcrnn_train.py --config_filename=data/model/dcrnn_la.yaml

# PEMS-BAY
python dcrnn_train.py --config_filename=data/model/dcrnn_bay.yaml

Each epoch takes about 5 min on METR-LA and about 10 min on PEMS-BAY, using a single GTX 1080 Ti.

The training loss may occasionally explode. As a temporary workaround, restart from the last model saved before the explosion, or decrease the learning rate earlier in the learning rate schedule.
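The loss-explosion workaround can be sketched independently of any framework in three steps:

1. Detect a diverging loss.
2. Roll back to the last good checkpoint.
3. Move the step-decay milestones earlier.

Everything in the following minimal sketch is hypothetical. The function names, the milestone epochs, the decay ratio, and the explosion test (a non-finite loss, or one more than 10x the best loss so far) are illustrative choices. They are not this repository's configuration keys or training loop.

```python
import math


def step_decay_lr(epoch, base_lr, milestones, decay_ratio=0.1):
    """Learning rate at `epoch`: base_lr times decay_ratio per milestone reached."""
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * decay_ratio ** passed


def is_exploding(loss, best_loss, factor=10.0):
    """Flag divergence: a NaN/inf loss, or one over `factor` times the best loss."""
    return not math.isfinite(loss) or loss > factor * best_loss


def earlier_milestones(milestones, shift):
    """Move every decay milestone `shift` epochs earlier (not before epoch 0)."""
    return [max(m - shift, 0) for m in milestones]


# Hypothetical run: milestones and losses are made up for illustration.
milestones = [20, 30, 40]
losses = [3.0, 2.5, 2.4] + [2.3] * 22 + [float("nan")]  # NaN at epoch 25
best_loss, last_good_epoch = math.inf, None
for epoch, loss in enumerate(losses):
    if is_exploding(loss, best_loss):
        # Real training would reload the checkpoint from last_good_epoch here
        # and resume with the new, earlier decay schedule.
        milestones = earlier_milestones(milestones, shift=5)
        break
    if loss < best_loss:
        best_loss, last_good_epoch = loss, epoch  # save a checkpoint here
print(last_good_epoch, milestones)  # 3 [15, 25, 35]
```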

Evaluate Baseline Methods

# METR-LA
python -m scripts.eval_baseline_methods --traffic_reading_filename=data/metr-la.h5
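The README does not say which baselines scripts.eval_baseline_methods covers. A historical average is one common reference point for this kind of comparison: it forecasts each time-of-day slot as the mean of all earlier readings in that same slot.

The following minimal sketch implements a historical average for a single sensor. The daily (rather than weekly) period, the 288 slots per day, and the function name are assumptions, not the contents of the script.

```python
from collections import defaultdict


def historical_average(readings, slots_per_day=288):
    """Historical-average forecast for one sensor (a hypothetical baseline).

    readings: values at consecutive 5-minute steps, so 288 slots per day.
    The prediction for step t is the mean of all earlier readings in the
    same time-of-day slot, or None when no earlier day is available.
    """
    sums = defaultdict(float)
    counts = defaultdict(int)
    preds = []
    for t, value in enumerate(readings):
        slot = t % slots_per_day
        preds.append(sums[slot] / counts[slot] if counts[slot] else None)
        sums[slot] += value  # update only after predicting, so no look-ahead
        counts[slot] += 1
    return preds


# Toy example with 2 slots per "day" to keep the numbers small.
print(historical_average([10, 20, 30, 40, 50, 60], slots_per_day=2))
# [None, None, 10.0, 20.0, 20.0, 30.0]
```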

More details are being added ...

Citation

If you find this repository, e.g., the code and the datasets, useful in your research, please cite the following paper:

@inproceedings{li2018dcrnn_traffic,
  title={Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting},
  author={Li, Yaguang and Yu, Rose and Shahabi, Cyrus and Liu, Yan},
  booktitle={International Conference on Learning Representations (ICLR '18)},
  year={2018}
}