From c8a676604b0ff18ff833274b38a0dd1cde00cf09 Mon Sep 17 00:00:00 2001 From: liyaguang Date: Tue, 8 Jan 2019 11:21:22 -0800 Subject: [PATCH] Add PEMS-BAY configuration. --- README.md | 28 ++++++++----- data/model/dcrnn_bay.yaml | 39 +++++++++++++++++++ .../{dcrnn_config.yaml => dcrnn_la.yaml} | 0 3 files changed, 58 insertions(+), 9 deletions(-) create mode 100644 data/model/dcrnn_bay.yaml rename data/model/{dcrnn_config.yaml => dcrnn_la.yaml} (100%) diff --git a/README.md b/README.md index 165a668..bc7060f 100644 --- a/README.md +++ b/README.md @@ -20,28 +20,38 @@ pip install -r requirements.txt ``` ## Data Preparation -The traffic data file for Los Angeles, i.e., `df_highway_2012_4mon_sample.h5`, is available at [Google Drive](https://drive.google.com/open?id=1tjf5aXCgUoimvADyxKqb-YUlxP8O46pb) or [Baidu Yun](https://pan.baidu.com/s/1rsCq38a9SRyFO1F68tUscA), and should be -put into the `data/METR-LA` folder. -Besides, the locations of sensors are available at [data/sensor_graph/graph_sensor_locations.csv](https://github.com/liyaguang/DCRNN/blob/master/data/sensor_graph/graph_sensor_locations.csv). +The traffic data files for Los Angeles and the Bay Area, i.e., `metr-la.h5` and `pems-bay.h5`, are available at [Google Drive](https://drive.google.com/open?id=10FOTa6HXPqX8Pf5WRoRwcFnW9BrNZEIX) or [Baidu Yun](hbttps://pan.baidu.com/s/14Yy9isAIZYdU__OYEQGa_g), and should be +put into the `data/` folder. +Besides, the locations of sensors Los Angeles are available at [data/sensor_graph/graph_sensor_locations.csv](https://github.com/liyaguang/DCRNN/blob/master/data/sensor_graph/graph_sensor_locations.csv). ```bash -python -m scripts.generate_training_data --output_dir=data/METR-LA +mkdir -p data/{METR-LA,PEMS-BAY} + +# METR-LA +python -m scripts.generate_training_data --output_dir=data/METR-LA --traffic_df_filename=data/metr-la.h5 + +# PEMS-BAY +python -m scripts.generate_training_data --output_dir=data/PEMS-BAY --traffic_df_filename=data/pems-bay.h5 ``` -The generated train/val/test dataset will be saved at `data/METR-LA/{train,val,test}.npz`. +The generated train/val/test dataset will be saved at `data/{METR-LA,PEMS-BAY}/{train,val,test}.npz`. -## Run the Pre-trained Model +## Run the Pre-trained Model on METR-LA ```bash python run_demo.py ``` -The generated prediction of DCRNN is in `data/results/dcrnn_predictions_[1-12].h5`. +The generated prediction of DCRNN of METR-LA is in `data/results/dcrnn_predictions_[1-12].h5`. ## Model Training ```bash -python dcrnn_train.py --config_filename=data/model/dcrnn_config.yaml +# METR-LA +python dcrnn_train.py --config_filename=data/model/dcrnn_la.yaml + +# PEMS-BAY +python dcrnn_train.py --config_filename=data/model/dcrnn_bay.yaml ``` -Each epoch takes about 5min with a single GTX 1080 Ti. +Each epoch takes about 5min or 10 min on a single GTX 1080 Ti for METR-LA or PEMS-BAY respectively. ## Graph Construction As the currently implementation is based on pre-calculated road network distances between sensors, it currently only diff --git a/data/model/dcrnn_bay.yaml b/data/model/dcrnn_bay.yaml new file mode 100644 index 0000000..31dec27 --- /dev/null +++ b/data/model/dcrnn_bay.yaml @@ -0,0 +1,39 @@ +--- +base_dir: data/model +log_level: INFO +data: + batch_size: 64 + dataset_dir: data/PEMS-BAY + test_batch_size: 64 + val_batch_size: 64 + graph_pkl_filename: data/sensor_graph/adj_mx_bay.pkl + +model: + cl_decay_steps: 2000 + filter_type: dual_random_walk + horizon: 12 + input_dim: 2 + l1_decay: 0 + max_diffusion_step: 2 + num_nodes: 325 + num_rnn_layers: 2 + output_dim: 1 + rnn_units: 64 + seq_len: 12 + use_curriculum_learning: true + +train: + base_lr: 0.01 + dropout: 0 + epoch: 0 + epochs: 100 + epsilon: 1.0e-3 + global_step: 0 + lr_decay_ratio: 0.1 + max_grad_norm: 5 + max_to_keep: 100 + min_learning_rate: 2.0e-06 + optimizer: adam + patience: 50 + steps: [20, 30, 40, 50] + test_every_n_epochs: 10 diff --git a/data/model/dcrnn_config.yaml b/data/model/dcrnn_la.yaml similarity index 100% rename from data/model/dcrnn_config.yaml rename to data/model/dcrnn_la.yaml