Add pretrained model on PEMS-BAY.
This commit is contained in:
parent
ad36deb794
commit
763eb7af69
|
|
@ -39,9 +39,13 @@ The generated train/val/test dataset will be saved at `data/{METR-LA,PEMS-BAY}/{
|
||||||
## Run the Pre-trained Model on METR-LA
|
## Run the Pre-trained Model on METR-LA
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python run_demo.py
|
# METR-LA
|
||||||
|
python run_demo.py --config_filename=data/model/pretrained/METR-LA/config.yaml
|
||||||
|
|
||||||
|
# PEMS-BAY
|
||||||
|
python run_demo.py --config_filename=data/model/pretrained/PEMS-BAY/config.yaml
|
||||||
```
|
```
|
||||||
The generated prediction of DCRNN of METR-LA is in `data/results/dcrnn_predictions_[1-12].h5`.
|
The generated prediction of DCRNN is in `data/results/dcrnn_predictions`.
|
||||||
|
|
||||||
|
|
||||||
## Model Training
|
## Model Training
|
||||||
|
|
|
||||||
|
|
@ -25,12 +25,12 @@ train:
|
||||||
epochs: 100
|
epochs: 100
|
||||||
epsilon: 0.001
|
epsilon: 0.001
|
||||||
global_step: 24375
|
global_step: 24375
|
||||||
log_dir: data/model/pretrained/
|
log_dir: data/model/pretrained/METR-LA
|
||||||
lr_decay_ratio: 0.1
|
lr_decay_ratio: 0.1
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
max_to_keep: 100
|
max_to_keep: 100
|
||||||
min_learning_rate: 2.0e-06
|
min_learning_rate: 2.0e-06
|
||||||
model_filename: data/model/pretrained/models-2.7422-24375
|
model_filename: data/model/pretrained/METR-LA/models-2.7422-24375
|
||||||
optimizer: adam
|
optimizer: adam
|
||||||
patience: 50
|
patience: 50
|
||||||
steps:
|
steps:
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
base_dir: data/model
|
||||||
|
data:
|
||||||
|
batch_size: 64
|
||||||
|
dataset_dir: data/PEMS-BAY
|
||||||
|
graph_pkl_filename: data/sensor_graph/adj_mx_bay.pkl
|
||||||
|
test_batch_size: 64
|
||||||
|
val_batch_size: 64
|
||||||
|
log_level: INFO
|
||||||
|
model:
|
||||||
|
cl_decay_steps: 2000
|
||||||
|
filter_type: dual_random_walk
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 2
|
||||||
|
l1_decay: 0
|
||||||
|
max_diffusion_step: 2
|
||||||
|
num_nodes: 325
|
||||||
|
num_rnn_layers: 2
|
||||||
|
output_dim: 1
|
||||||
|
rnn_units: 64
|
||||||
|
seq_len: 12
|
||||||
|
use_curriculum_learning: true
|
||||||
|
train:
|
||||||
|
base_lr: 0.01
|
||||||
|
dropout: 0
|
||||||
|
epoch: 53
|
||||||
|
epochs: 100
|
||||||
|
epsilon: 0.001
|
||||||
|
global_step: 30780
|
||||||
|
log_dir: data/model/pretrained/PEMS-BAY/
|
||||||
|
lr_decay_ratio: 0.1
|
||||||
|
max_grad_norm: 5
|
||||||
|
max_to_keep: 100
|
||||||
|
min_learning_rate: 2.0e-06
|
||||||
|
model_filename: data/model/pretrained/PEMS-BAY/models-1.6139-30780
|
||||||
|
optimizer: adam
|
||||||
|
patience: 50
|
||||||
|
steps:
|
||||||
|
- 20
|
||||||
|
- 30
|
||||||
|
- 40
|
||||||
|
- 50
|
||||||
|
test_every_n_epochs: 10
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -10,13 +10,13 @@ from model.dcrnn_supervisor import DCRNNSupervisor
|
||||||
|
|
||||||
|
|
||||||
def run_dcrnn(args):
|
def run_dcrnn(args):
|
||||||
graph_pkl_filename = 'data/sensor_graph/adj_mx.pkl'
|
|
||||||
with open(args.config_filename) as f:
|
with open(args.config_filename) as f:
|
||||||
config = yaml.load(f)
|
config = yaml.load(f)
|
||||||
tf_config = tf.ConfigProto()
|
tf_config = tf.ConfigProto()
|
||||||
if args.use_cpu_only:
|
if args.use_cpu_only:
|
||||||
tf_config = tf.ConfigProto(device_count={'GPU': 0})
|
tf_config = tf.ConfigProto(device_count={'GPU': 0})
|
||||||
tf_config.gpu_options.allow_growth = True
|
tf_config.gpu_options.allow_growth = True
|
||||||
|
graph_pkl_filename = config['data']['graph_pkl_filename']
|
||||||
_, _, adj_mx = load_graph_data(graph_pkl_filename)
|
_, _, adj_mx = load_graph_data(graph_pkl_filename)
|
||||||
with tf.Session(config=tf_config) as sess:
|
with tf.Session(config=tf_config) as sess:
|
||||||
supervisor = DCRNNSupervisor(adj_mx=adj_mx, **config)
|
supervisor = DCRNNSupervisor(adj_mx=adj_mx, **config)
|
||||||
|
|
@ -30,7 +30,7 @@ if __name__ == '__main__':
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--use_cpu_only', default=False, type=str, help='Whether to run tensorflow on cpu.')
|
parser.add_argument('--use_cpu_only', default=False, type=str, help='Whether to run tensorflow on cpu.')
|
||||||
parser.add_argument('--config_filename', default='data/model/pretrained/config.yaml', type=str,
|
parser.add_argument('--config_filename', default='data/model/pretrained/METR-LA/config.yaml', type=str,
|
||||||
help='Config file for pretrained model.')
|
help='Config file for pretrained model.')
|
||||||
parser.add_argument('--output_filename', default='data/dcrnn_predictions.npz')
|
parser.add_argument('--output_filename', default='data/dcrnn_predictions.npz')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue