diff --git a/README.md b/README.md index cfebdac..8bb9ff2 100644 --- a/README.md +++ b/README.md @@ -39,9 +39,13 @@ The generated train/val/test dataset will be saved at `data/{METR-LA,PEMS-BAY}/{ ## Run the Pre-trained Model on METR-LA ```bash -python run_demo.py +# METR-LA +python run_demo.py --config_filename=data/model/pretrained/METR-LA/config.yaml + +# PEMS-BAY +python run_demo.py --config_filename=data/model/pretrained/PEMS-BAY/config.yaml ``` -The generated prediction of DCRNN of METR-LA is in `data/results/dcrnn_predictions_[1-12].h5`. +The generated prediction of DCRNN is in `data/results/dcrnn_predictions`. ## Model Training diff --git a/data/model/pretrained/config.yaml b/data/model/pretrained/METR-LA/config.yaml similarity index 86% rename from data/model/pretrained/config.yaml rename to data/model/pretrained/METR-LA/config.yaml index 5292f70..b0227a8 100644 --- a/data/model/pretrained/config.yaml +++ b/data/model/pretrained/METR-LA/config.yaml @@ -25,12 +25,12 @@ train: epochs: 100 epsilon: 0.001 global_step: 24375 - log_dir: data/model/pretrained/ + log_dir: data/model/pretrained/METR-LA lr_decay_ratio: 0.1 max_grad_norm: 5 max_to_keep: 100 min_learning_rate: 2.0e-06 - model_filename: data/model/pretrained/models-2.7422-24375 + model_filename: data/model/pretrained/METR-LA/models-2.7422-24375 optimizer: adam patience: 50 steps: diff --git a/data/model/pretrained/models-2.7422-24375.data-00000-of-00001 b/data/model/pretrained/METR-LA/models-2.7422-24375.data-00000-of-00001 similarity index 100% rename from data/model/pretrained/models-2.7422-24375.data-00000-of-00001 rename to data/model/pretrained/METR-LA/models-2.7422-24375.data-00000-of-00001 diff --git a/data/model/pretrained/models-2.7422-24375.index b/data/model/pretrained/METR-LA/models-2.7422-24375.index similarity index 100% rename from data/model/pretrained/models-2.7422-24375.index rename to data/model/pretrained/METR-LA/models-2.7422-24375.index diff --git a/data/model/pretrained/PEMS-BAY/config.yaml b/data/model/pretrained/PEMS-BAY/config.yaml new file mode 100644 index 0000000..3fb9042 --- /dev/null +++ b/data/model/pretrained/PEMS-BAY/config.yaml @@ -0,0 +1,42 @@ +base_dir: data/model +data: + batch_size: 64 + dataset_dir: data/PEMS-BAY + graph_pkl_filename: data/sensor_graph/adj_mx_bay.pkl + test_batch_size: 64 + val_batch_size: 64 +log_level: INFO +model: + cl_decay_steps: 2000 + filter_type: dual_random_walk + horizon: 12 + input_dim: 2 + l1_decay: 0 + max_diffusion_step: 2 + num_nodes: 325 + num_rnn_layers: 2 + output_dim: 1 + rnn_units: 64 + seq_len: 12 + use_curriculum_learning: true +train: + base_lr: 0.01 + dropout: 0 + epoch: 53 + epochs: 100 + epsilon: 0.001 + global_step: 30780 + log_dir: data/model/pretrained/PEMS-BAY/ + lr_decay_ratio: 0.1 + max_grad_norm: 5 + max_to_keep: 100 + min_learning_rate: 2.0e-06 + model_filename: data/model/pretrained/PEMS-BAY/models-1.6139-30780 + optimizer: adam + patience: 50 + steps: + - 20 + - 30 + - 40 + - 50 + test_every_n_epochs: 10 diff --git a/data/model/pretrained/PEMS-BAY/events.out.tfevents.1547170277.kakarot b/data/model/pretrained/PEMS-BAY/events.out.tfevents.1547170277.kakarot new file mode 100644 index 0000000..e829a9a Binary files /dev/null and b/data/model/pretrained/PEMS-BAY/events.out.tfevents.1547170277.kakarot differ diff --git a/data/model/pretrained/PEMS-BAY/models-1.6139-30780.data-00000-of-00001 b/data/model/pretrained/PEMS-BAY/models-1.6139-30780.data-00000-of-00001 new file mode 100644 index 0000000..c6b1b0f Binary files /dev/null and b/data/model/pretrained/PEMS-BAY/models-1.6139-30780.data-00000-of-00001 differ diff --git a/data/model/pretrained/PEMS-BAY/models-1.6139-30780.index b/data/model/pretrained/PEMS-BAY/models-1.6139-30780.index new file mode 100644 index 0000000..6b5fcce Binary files /dev/null and b/data/model/pretrained/PEMS-BAY/models-1.6139-30780.index differ diff --git a/run_demo.py b/run_demo.py index 7c408fc..ecbbe86 100644 --- a/run_demo.py +++ b/run_demo.py @@ -10,13 +10,13 @@ from model.dcrnn_supervisor import DCRNNSupervisor def run_dcrnn(args): - graph_pkl_filename = 'data/sensor_graph/adj_mx.pkl' with open(args.config_filename) as f: config = yaml.load(f) tf_config = tf.ConfigProto() if args.use_cpu_only: tf_config = tf.ConfigProto(device_count={'GPU': 0}) tf_config.gpu_options.allow_growth = True + graph_pkl_filename = config['data']['graph_pkl_filename'] _, _, adj_mx = load_graph_data(graph_pkl_filename) with tf.Session(config=tf_config) as sess: supervisor = DCRNNSupervisor(adj_mx=adj_mx, **config) @@ -30,7 +30,7 @@ if __name__ == '__main__': sys.path.append(os.getcwd()) parser = argparse.ArgumentParser() parser.add_argument('--use_cpu_only', default=False, type=str, help='Whether to run tensorflow on cpu.') - parser.add_argument('--config_filename', default='data/model/pretrained/config.yaml', type=str, + parser.add_argument('--config_filename', default='data/model/pretrained/METR-LA/config.yaml', type=str, help='Config file for pretrained model.') parser.add_argument('--output_filename', default='data/dcrnn_predictions.npz') args = parser.parse_args()