Changed README to reflect PyTorch implementation

This commit is contained in:
Chintan Shah 2019-10-30 12:30:45 -04:00
parent 073f1d4a6e
commit bb32eb0f46
8 changed files with 25 additions and 20 deletions

View File

@ -2,17 +2,21 @@
![Diffusion Convolutional Recurrent Neural Network](figures/model_architecture.jpg "Model Architecture")
This is a TensorFlow implementation of Diffusion Convolutional Recurrent Neural Network in the following paper: \
This is a PyTorch implementation of Diffusion Convolutional Recurrent Neural Network in the following paper: \
Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu, [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](https://arxiv.org/abs/1707.01926), ICLR 2018.
## Requirements
- scipy>=0.19.0
- numpy>=1.12.1
- pandas>=0.19.2
- pyaml
- statsmodels
- tensorflow>=1.3.0
torch
scipy>=0.19.0
numpy>=1.12.1
pandas>=0.19.2
pyyaml
statsmodels
tensorflow>=1.3.0
torch
tables
future
Dependency can be installed using the following command:
@ -60,10 +64,10 @@ Besides, the locations of sensors in Los Angeles, i.e., METR-LA, are available a
```bash
# METR-LA
python run_demo.py --config_filename=data/model/pretrained/METR-LA/config.yaml
python run_demo_pytorch.py --config_filename=data/model/pretrained/METR-LA/config.yaml
# PEMS-BAY
python run_demo.py --config_filename=data/model/pretrained/PEMS-BAY/config.yaml
python run_demo_pytorch.py --config_filename=data/model/pretrained/PEMS-BAY/config.yaml
```
The generated prediction of DCRNN is in `data/results/dcrnn_predictions`.
@ -71,12 +75,11 @@ The generated prediction of DCRNN is in `data/results/dcrnn_predictions`.
## Model Training
```bash
# METR-LA
python dcrnn_train.py --config_filename=data/model/dcrnn_la.yaml
python dcrnn_train_pytorch.py --config_filename=data/model/dcrnn_la.yaml
# PEMS-BAY
python dcrnn_train.py --config_filename=data/model/dcrnn_bay.yaml
python dcrnn_train_pytorch.py --config_filename=data/model/dcrnn_bay.yaml
```
Each epoch takes about 5min or 10 min on a single GTX 1080 Ti for METR-LA or PEMS-BAY respectively.
There is a chance that the training loss will explode, the temporary workaround is to restart from the last saved model before the explosion, or to decrease the learning rate earlier in the learning rate schedule.
@ -87,7 +90,15 @@ There is a chance that the training loss will explode, the temporary workaround
python -m scripts.eval_baseline_methods --traffic_reading_filename=data/metr-la.h5
```
More details are being added ...
### PyTorch Results
![PyTorch Results](figures/result1.png "PyTorch Results")
![PyTorch Results](figures/result2.png "PyTorch Results")
![PyTorch Results](figures/result3.png "PyTorch Results")
![PyTorch Results](figures/result4.png "PyTorch Results")
## Citation

Binary file not shown.

View File

@ -3,12 +3,12 @@ from __future__ import division
from __future__ import print_function
import argparse
import tensorflow as tf
import yaml
from lib.utils import load_graph_data
from model.pytorch.dcrnn_supervisor import DCRNNSupervisor
def main(args):
with open(args.config_filename) as f:
supervisor_config = yaml.load(f)
@ -16,9 +16,6 @@ def main(args):
graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename')
sensor_ids, sensor_id_to_ind, adj_mx = load_graph_data(graph_pkl_filename)
# if args.use_cpu_only:
# tf_config = tf.ConfigProto(device_count={'GPU': 0})
# with tf.Session(config=tf_config) as sess:
supervisor = DCRNNSupervisor(adj_mx=adj_mx, **supervisor_config)
supervisor.train()

BIN
figures/result1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 203 KiB

BIN
figures/result2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 287 KiB

BIN
figures/result3.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 254 KiB

BIN
figures/result4.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 205 KiB

View File

@ -15,9 +15,6 @@ def run_dcrnn(args):
graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename')
sensor_ids, sensor_id_to_ind, adj_mx = load_graph_data(graph_pkl_filename)
# if args.use_cpu_only:
# tf_config = tf.ConfigProto(device_count={'GPU': 0})
# with tf.Session(config=tf_config) as sess:
supervisor = DCRNNSupervisor(adj_mx=adj_mx, **supervisor_config)
mean_score, outputs = supervisor.evaluate('test')
np.savez_compressed(args.output_filename, **outputs)