Update pretrained model.
This commit is contained in:
parent
e0212cc178
commit
9520e6cf85
|
|
@ -14,7 +14,6 @@ model:
|
||||||
input_dim: 2
|
input_dim: 2
|
||||||
l1_decay: 0
|
l1_decay: 0
|
||||||
max_diffusion_step: 2
|
max_diffusion_step: 2
|
||||||
max_grad_norm: 5
|
|
||||||
num_nodes: 207
|
num_nodes: 207
|
||||||
num_rnn_layers: 2
|
num_rnn_layers: 2
|
||||||
output_dim: 1
|
output_dim: 1
|
||||||
|
|
@ -27,10 +26,13 @@ train:
|
||||||
dropout: 0
|
dropout: 0
|
||||||
epoch: 0
|
epoch: 0
|
||||||
epochs: 100
|
epochs: 100
|
||||||
|
epsilon: 1.0e-3
|
||||||
global_step: 0
|
global_step: 0
|
||||||
lr_decay_ratio: 0.1
|
lr_decay_ratio: 0.1
|
||||||
steps: [20, 30, 40, 50]
|
max_grad_norm: 5
|
||||||
max_to_keep: 100
|
max_to_keep: 100
|
||||||
min_learning_rate: 2.0e-06
|
min_learning_rate: 2.0e-06
|
||||||
|
optimizer: adam
|
||||||
patience: 50
|
patience: 50
|
||||||
|
steps: [20, 30, 40, 50]
|
||||||
test_every_n_epochs: 10
|
test_every_n_epochs: 10
|
||||||
|
|
@ -0,0 +1,40 @@
|
||||||
|
base_dir: data/model
|
||||||
|
data:
|
||||||
|
batch_size: 64
|
||||||
|
dataset_dir: data/METR-LA
|
||||||
|
graph_pkl_filename: data/sensor_graph/adj_mx.pkl
|
||||||
|
test_batch_size: 64
|
||||||
|
model:
|
||||||
|
cl_decay_steps: 2000
|
||||||
|
filter_type: dual_random_walk
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 2
|
||||||
|
l1_decay: 0
|
||||||
|
max_diffusion_step: 2
|
||||||
|
num_nodes: 207
|
||||||
|
num_rnn_layers: 2
|
||||||
|
output_dim: 1
|
||||||
|
rnn_units: 64
|
||||||
|
seq_len: 12
|
||||||
|
use_curriculum_learning: true
|
||||||
|
train:
|
||||||
|
base_lr: 0.01
|
||||||
|
dropout: 0
|
||||||
|
epoch: 64
|
||||||
|
epochs: 100
|
||||||
|
epsilon: 0.001
|
||||||
|
global_step: 24375
|
||||||
|
log_dir: data/model/pretrained/
|
||||||
|
lr_decay_ratio: 0.1
|
||||||
|
max_grad_norm: 5
|
||||||
|
max_to_keep: 100
|
||||||
|
min_learning_rate: 2.0e-06
|
||||||
|
model_filename: data/model/pretrained/models-2.7422-24375
|
||||||
|
optimizer: adam
|
||||||
|
patience: 50
|
||||||
|
steps:
|
||||||
|
- 20
|
||||||
|
- 30
|
||||||
|
- 40
|
||||||
|
- 50
|
||||||
|
test_every_n_epochs: 10
|
||||||
Binary file not shown.
Binary file not shown.
18
run_demo.py
18
run_demo.py
|
|
@ -1,6 +1,6 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
import pandas as pd
|
|
||||||
import sys
|
import sys
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -21,21 +21,17 @@ def run_dcrnn(args):
|
||||||
with tf.Session(config=tf_config) as sess:
|
with tf.Session(config=tf_config) as sess:
|
||||||
supervisor = DCRNNSupervisor(adj_mx=adj_mx, **config)
|
supervisor = DCRNNSupervisor(adj_mx=adj_mx, **config)
|
||||||
supervisor.restore(sess, config=config)
|
supervisor.restore(sess, config=config)
|
||||||
df_preds = supervisor.test_and_write_result(sess, config['global_step'])
|
outputs = supervisor.test_and_write_result(sess, config['train']['global_step'])
|
||||||
# TODO (yaguang): save this file to the npz file.
|
np.savez_compressed(args.output_filename, **outputs)
|
||||||
for horizon_i in df_preds:
|
print('Predictions saved as {}.'.format(args.output_filename))
|
||||||
df_pred = df_preds[horizon_i]
|
|
||||||
filename = os.path.join('data/results/', 'dcrnn_prediction_%d.h5' % (horizon_i + 1))
|
|
||||||
df_pred.to_hdf(filename, 'results')
|
|
||||||
print('Predictions saved as data/results/dcrnn_seq2seq_prediction_[1-12].h5...')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--traffic_df_filename', default='data/df_highway_2012_4mon_sample.h5',
|
|
||||||
type=str, help='Traffic data file.')
|
|
||||||
parser.add_argument('--use_cpu_only', default=False, type=str, help='Whether to run tensorflow on cpu.')
|
parser.add_argument('--use_cpu_only', default=False, type=str, help='Whether to run tensorflow on cpu.')
|
||||||
parser.add_argument('--config_filename', default=None, type=str, help='Config file for pretrained model.')
|
parser.add_argument('--config_filename', default='data/model/pretrained/config.yaml', type=str,
|
||||||
|
help='Config file for pretrained model.')
|
||||||
|
parser.add_argument('--output_filename', default='data/dcrnn_predictions.npz')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
run_dcrnn(args)
|
run_dcrnn(args)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue