Update pretrained model.
This commit is contained in:
parent
e0212cc178
commit
9520e6cf85
|
|
@ -14,7 +14,6 @@ model:
|
|||
input_dim: 2
|
||||
l1_decay: 0
|
||||
max_diffusion_step: 2
|
||||
max_grad_norm: 5
|
||||
num_nodes: 207
|
||||
num_rnn_layers: 2
|
||||
output_dim: 1
|
||||
|
|
@ -27,10 +26,13 @@ train:
|
|||
dropout: 0
|
||||
epoch: 0
|
||||
epochs: 100
|
||||
epsilon: 1.0e-3
|
||||
global_step: 0
|
||||
lr_decay_ratio: 0.1
|
||||
steps: [20, 30, 40, 50]
|
||||
max_grad_norm: 5
|
||||
max_to_keep: 100
|
||||
min_learning_rate: 2.0e-06
|
||||
optimizer: adam
|
||||
patience: 50
|
||||
steps: [20, 30, 40, 50]
|
||||
test_every_n_epochs: 10
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
base_dir: data/model
|
||||
data:
|
||||
batch_size: 64
|
||||
dataset_dir: data/METR-LA
|
||||
graph_pkl_filename: data/sensor_graph/adj_mx.pkl
|
||||
test_batch_size: 64
|
||||
model:
|
||||
cl_decay_steps: 2000
|
||||
filter_type: dual_random_walk
|
||||
horizon: 12
|
||||
input_dim: 2
|
||||
l1_decay: 0
|
||||
max_diffusion_step: 2
|
||||
num_nodes: 207
|
||||
num_rnn_layers: 2
|
||||
output_dim: 1
|
||||
rnn_units: 64
|
||||
seq_len: 12
|
||||
use_curriculum_learning: true
|
||||
train:
|
||||
base_lr: 0.01
|
||||
dropout: 0
|
||||
epoch: 64
|
||||
epochs: 100
|
||||
epsilon: 0.001
|
||||
global_step: 24375
|
||||
log_dir: data/model/pretrained/
|
||||
lr_decay_ratio: 0.1
|
||||
max_grad_norm: 5
|
||||
max_to_keep: 100
|
||||
min_learning_rate: 2.0e-06
|
||||
model_filename: data/model/pretrained/models-2.7422-24375
|
||||
optimizer: adam
|
||||
patience: 50
|
||||
steps:
|
||||
- 20
|
||||
- 30
|
||||
- 40
|
||||
- 50
|
||||
test_every_n_epochs: 10
|
||||
Binary file not shown.
Binary file not shown.
18
run_demo.py
18
run_demo.py
|
|
@ -1,6 +1,6 @@
|
|||
import argparse
|
||||
import numpy as np
|
||||
import os
|
||||
import pandas as pd
|
||||
import sys
|
||||
import tensorflow as tf
|
||||
import yaml
|
||||
|
|
@ -21,21 +21,17 @@ def run_dcrnn(args):
|
|||
with tf.Session(config=tf_config) as sess:
|
||||
supervisor = DCRNNSupervisor(adj_mx=adj_mx, **config)
|
||||
supervisor.restore(sess, config=config)
|
||||
df_preds = supervisor.test_and_write_result(sess, config['global_step'])
|
||||
# TODO (yaguang): save this file to the npz file.
|
||||
for horizon_i in df_preds:
|
||||
df_pred = df_preds[horizon_i]
|
||||
filename = os.path.join('data/results/', 'dcrnn_prediction_%d.h5' % (horizon_i + 1))
|
||||
df_pred.to_hdf(filename, 'results')
|
||||
print('Predictions saved as data/results/dcrnn_seq2seq_prediction_[1-12].h5...')
|
||||
outputs = supervisor.test_and_write_result(sess, config['train']['global_step'])
|
||||
np.savez_compressed(args.output_filename, **outputs)
|
||||
print('Predictions saved as {}.'.format(args.output_filename))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.path.append(os.getcwd())
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--traffic_df_filename', default='data/df_highway_2012_4mon_sample.h5',
|
||||
type=str, help='Traffic data file.')
|
||||
parser.add_argument('--use_cpu_only', default=False, type=str, help='Whether to run tensorflow on cpu.')
|
||||
parser.add_argument('--config_filename', default=None, type=str, help='Config file for pretrained model.')
|
||||
parser.add_argument('--config_filename', default='data/model/pretrained/config.yaml', type=str,
|
||||
help='Config file for pretrained model.')
|
||||
parser.add_argument('--output_filename', default='data/dcrnn_predictions.npz')
|
||||
args = parser.parse_args()
|
||||
run_dcrnn(args)
|
||||
|
|
|
|||
Loading…
Reference in New Issue