diff --git a/README.md b/README.md index b6b84eb..1351ae3 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,6 @@ Yaguang Li, Rose Yu, Cyrus Shahabi, Yan Liu, [Diffusion Convolutional Recurrent ## Requirements -- hyperopt>=0.1 - scipy>=0.19.0 - numpy>=1.12.1 - pandas>=0.19.2 @@ -22,25 +21,14 @@ Dependency can be installed using the following command: pip install -r requirements.txt ``` - -## Traffic Data +## Data Preparation The traffic data file for Los Angeles, i.e., `df_highway_2012_4mon_sample.h5`, is available [here](https://drive.google.com/open?id=1tjf5aXCgUoimvADyxKqb-YUlxP8O46pb), and should be -put into the `data/` folder. +put into the `data/METR-LA` folder. Besides, the locations of sensors are available at [data/sensor_graph/graph_sensor_locations.csv](https://github.com/liyaguang/DCRNN/blob/master/data/sensor_graph/graph_sensor_locations.csv). - -## Graph Construction - As the currently implementation is based on pre-calculated road network distances between sensors, it currently only - supports sensor ids in Los Angeles (see `data/sensor_graph/sensor_info_201206.csv`). - ```bash -python gen_adj_mx.py --sensor_ids_filename=data/sensor_graph/graph_sensor_ids.txt --normalized_k=0.1\ - --output_pkl_filename=data/sensor_graph/adj_mx.pkl -``` - -## Train the Model -```bash -python dcrnn_train.py --config_filename=data/model/dcrnn_config.yaml +python -m scripts.generate_training_data --output_dir=data/METR-LA ``` +The generated train/val/test dataset will be saved at `data/METR-LA/{train,val,test}.npz`. ## Run the Pre-trained Model @@ -51,6 +39,21 @@ python run_demo.py The generated prediction of DCRNN is in `data/results/dcrnn_predictions_[1-12].h5`. +## Model Training +```bash +python dcrnn_train.py --config_filename=data/model/dcrnn_config.yaml +``` +Each epoch takes about 5min with a single GTX 1080 Ti. There is a chance that train/val loss will explode, gradient explosion, + +## Graph Construction + As the currently implementation is based on pre-calculated road network distances between sensors, it currently only + supports sensor ids in Los Angeles (see `data/sensor_graph/sensor_info_201206.csv`). + +```bash +python gen_adj_mx.py --sensor_ids_filename=data/sensor_graph/graph_sensor_ids.txt --normalized_k=0.1\ + --output_pkl_filename=data/sensor_graph/adj_mx.pkl +``` + More details are being added ... ## Citation diff --git a/requirements.txt b/requirements.txt index 6ab6fc5..ec17d78 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -hyperopt>=0.1 scipy>=0.19.0 numpy>=1.12.1 pandas>=0.19.2