Changed README to reflect PyTorch implementation
This commit is contained in:
parent
073f1d4a6e
commit
bb32eb0f46
37
README.md
37
README.md
|
|
@ -2,17 +2,21 @@
|
|||
|
||||

|
||||
|
||||
This is a TensorFlow implementation of Diffusion Convolutional Recurrent Neural Network in the following paper: \
|
||||
This is a PyTorch implementation of Diffusion Convolutional Recurrent Neural Network in the following paper: \
|
||||
Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu, [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](https://arxiv.org/abs/1707.01926), ICLR 2018.
|
||||
|
||||
|
||||
## Requirements
|
||||
- scipy>=0.19.0
|
||||
- numpy>=1.12.1
|
||||
- pandas>=0.19.2
|
||||
- pyaml
|
||||
- statsmodels
|
||||
- tensorflow>=1.3.0
|
||||
torch
|
||||
scipy>=0.19.0
|
||||
numpy>=1.12.1
|
||||
pandas>=0.19.2
|
||||
pyyaml
|
||||
statsmodels
|
||||
tensorflow>=1.3.0
|
||||
torch
|
||||
tables
|
||||
future
|
||||
|
||||
|
||||
Dependency can be installed using the following command:
|
||||
|
|
@ -60,10 +64,10 @@ Besides, the locations of sensors in Los Angeles, i.e., METR-LA, are available a
|
|||
|
||||
```bash
|
||||
# METR-LA
|
||||
python run_demo.py --config_filename=data/model/pretrained/METR-LA/config.yaml
|
||||
python run_demo_pytorch.py --config_filename=data/model/pretrained/METR-LA/config.yaml
|
||||
|
||||
# PEMS-BAY
|
||||
python run_demo.py --config_filename=data/model/pretrained/PEMS-BAY/config.yaml
|
||||
python run_demo_pytorch.py --config_filename=data/model/pretrained/PEMS-BAY/config.yaml
|
||||
```
|
||||
The generated prediction of DCRNN is in `data/results/dcrnn_predictions`.
|
||||
|
||||
|
|
@ -71,12 +75,11 @@ The generated prediction of DCRNN is in `data/results/dcrnn_predictions`.
|
|||
## Model Training
|
||||
```bash
|
||||
# METR-LA
|
||||
python dcrnn_train.py --config_filename=data/model/dcrnn_la.yaml
|
||||
python dcrnn_train_pytorch.py --config_filename=data/model/dcrnn_la.yaml
|
||||
|
||||
# PEMS-BAY
|
||||
python dcrnn_train.py --config_filename=data/model/dcrnn_bay.yaml
|
||||
python dcrnn_train_pytorch.py --config_filename=data/model/dcrnn_bay.yaml
|
||||
```
|
||||
Each epoch takes about 5min or 10 min on a single GTX 1080 Ti for METR-LA or PEMS-BAY respectively.
|
||||
|
||||
There is a chance that the training loss will explode, the temporary workaround is to restart from the last saved model before the explosion, or to decrease the learning rate earlier in the learning rate schedule.
|
||||
|
||||
|
|
@ -87,7 +90,15 @@ There is a chance that the training loss will explode, the temporary workaround
|
|||
python -m scripts.eval_baseline_methods --traffic_reading_filename=data/metr-la.h5
|
||||
```
|
||||
|
||||
More details are being added ...
|
||||
### PyTorch Results
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
## Citation
|
||||
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -3,12 +3,12 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import tensorflow as tf
|
||||
import yaml
|
||||
|
||||
from lib.utils import load_graph_data
|
||||
from model.pytorch.dcrnn_supervisor import DCRNNSupervisor
|
||||
|
||||
|
||||
def main(args):
|
||||
with open(args.config_filename) as f:
|
||||
supervisor_config = yaml.load(f)
|
||||
|
|
@ -16,9 +16,6 @@ def main(args):
|
|||
graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename')
|
||||
sensor_ids, sensor_id_to_ind, adj_mx = load_graph_data(graph_pkl_filename)
|
||||
|
||||
# if args.use_cpu_only:
|
||||
# tf_config = tf.ConfigProto(device_count={'GPU': 0})
|
||||
# with tf.Session(config=tf_config) as sess:
|
||||
supervisor = DCRNNSupervisor(adj_mx=adj_mx, **supervisor_config)
|
||||
|
||||
supervisor.train()
|
||||
|
|
|
|||
Binary file not shown.
|
After Width: | Height: | Size: 203 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 287 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 254 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 205 KiB |
|
|
@ -15,9 +15,6 @@ def run_dcrnn(args):
|
|||
graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename')
|
||||
sensor_ids, sensor_id_to_ind, adj_mx = load_graph_data(graph_pkl_filename)
|
||||
|
||||
# if args.use_cpu_only:
|
||||
# tf_config = tf.ConfigProto(device_count={'GPU': 0})
|
||||
# with tf.Session(config=tf_config) as sess:
|
||||
supervisor = DCRNNSupervisor(adj_mx=adj_mx, **supervisor_config)
|
||||
mean_score, outputs = supervisor.evaluate('test')
|
||||
np.savez_compressed(args.output_filename, **outputs)
|
||||
|
|
|
|||
Loading…
Reference in New Issue