diff --git a/README.md b/README.md index 3b7a458..da283e1 100644 --- a/README.md +++ b/README.md @@ -2,17 +2,21 @@ ![Diffusion Convolutional Recurrent Neural Network](figures/model_architecture.jpg "Model Architecture") -This is a TensorFlow implementation of Diffusion Convolutional Recurrent Neural Network in the following paper: \ +This is a PyTorch implementation of Diffusion Convolutional Recurrent Neural Network in the following paper: \ Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu, [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](https://arxiv.org/abs/1707.01926), ICLR 2018. ## Requirements -- scipy>=0.19.0 -- numpy>=1.12.1 -- pandas>=0.19.2 -- pyaml -- statsmodels -- tensorflow>=1.3.0 +- torch +- scipy>=0.19.0 +- numpy>=1.12.1 +- pandas>=0.19.2 +- pyyaml +- statsmodels +- tensorflow>=1.3.0 +- torch +- tables +- future Dependency can be installed using the following command: @@ -60,10 +64,10 @@ Besides, the locations of sensors in Los Angeles, i.e., METR-LA, are available a ```bash # METR-LA -python run_demo.py --config_filename=data/model/pretrained/METR-LA/config.yaml +python run_demo_pytorch.py --config_filename=data/model/pretrained/METR-LA/config.yaml # PEMS-BAY -python run_demo.py --config_filename=data/model/pretrained/PEMS-BAY/config.yaml +python run_demo_pytorch.py --config_filename=data/model/pretrained/PEMS-BAY/config.yaml ``` The generated prediction of DCRNN is in `data/results/dcrnn_predictions`. @@ -71,12 +75,11 @@ The generated prediction of DCRNN is in `data/results/dcrnn_predictions`. ## Model Training ```bash # METR-LA -python dcrnn_train.py --config_filename=data/model/dcrnn_la.yaml +python dcrnn_train_pytorch.py --config_filename=data/model/dcrnn_la.yaml # PEMS-BAY -python dcrnn_train.py --config_filename=data/model/dcrnn_bay.yaml +python dcrnn_train_pytorch.py --config_filename=data/model/dcrnn_bay.yaml ``` -Each epoch takes about 5min or 10 min on a single GTX 1080 Ti for METR-LA or PEMS-BAY respectively. 
There is a chance that the training loss will explode, the temporary workaround is to restart from the last saved model before the explosion, or to decrease the learning rate earlier in the learning rate schedule. @@ -87,7 +90,15 @@ There is a chance that the training loss will explode, the temporary workaround python -m scripts.eval_baseline_methods --traffic_reading_filename=data/metr-la.h5 ``` -More details are being added ... +### PyTorch Results + +![PyTorch Results](figures/result1.png "PyTorch Results") + +![PyTorch Results](figures/result2.png "PyTorch Results") + +![PyTorch Results](figures/result3.png "PyTorch Results") + +![PyTorch Results](figures/result4.png "PyTorch Results") ## Citation diff --git a/data/dcrnn_predictions_pytorch.npz b/data/dcrnn_predictions_pytorch.npz new file mode 100644 index 0000000..f2ba63a Binary files /dev/null and b/data/dcrnn_predictions_pytorch.npz differ diff --git a/dcrnn_train_pytorch.py b/dcrnn_train_pytorch.py index b01f541..8764fde 100644 --- a/dcrnn_train_pytorch.py +++ b/dcrnn_train_pytorch.py @@ -3,12 +3,12 @@ from __future__ import division from __future__ import print_function import argparse -import tensorflow as tf import yaml from lib.utils import load_graph_data from model.pytorch.dcrnn_supervisor import DCRNNSupervisor + def main(args): with open(args.config_filename) as f: supervisor_config = yaml.load(f) @@ -16,9 +16,6 @@ def main(args): graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename') sensor_ids, sensor_id_to_ind, adj_mx = load_graph_data(graph_pkl_filename) - # if args.use_cpu_only: - # tf_config = tf.ConfigProto(device_count={'GPU': 0}) - # with tf.Session(config=tf_config) as sess: supervisor = DCRNNSupervisor(adj_mx=adj_mx, **supervisor_config) supervisor.train() diff --git a/figures/result1.png b/figures/result1.png new file mode 100644 index 0000000..9650fa3 Binary files /dev/null and b/figures/result1.png differ diff --git a/figures/result2.png b/figures/result2.png new 
file mode 100644 index 0000000..76228bd Binary files /dev/null and b/figures/result2.png differ diff --git a/figures/result3.png b/figures/result3.png new file mode 100644 index 0000000..45b7e17 Binary files /dev/null and b/figures/result3.png differ diff --git a/figures/result4.png b/figures/result4.png new file mode 100644 index 0000000..0cc8136 Binary files /dev/null and b/figures/result4.png differ diff --git a/run_demo_pytorch.py b/run_demo_pytorch.py index caa9dbe..714facd 100644 --- a/run_demo_pytorch.py +++ b/run_demo_pytorch.py @@ -15,9 +15,6 @@ def run_dcrnn(args): graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename') sensor_ids, sensor_id_to_ind, adj_mx = load_graph_data(graph_pkl_filename) - # if args.use_cpu_only: - # tf_config = tf.ConfigProto(device_count={'GPU': 0}) - # with tf.Session(config=tf_config) as sess: supervisor = DCRNNSupervisor(adj_mx=adj_mx, **supervisor_config) mean_score, outputs = supervisor.evaluate('test') np.savez_compressed(args.output_filename, **outputs)