Update the pretrained model and results.

This commit is contained in:
Yaguang 2018-06-07 11:18:55 +08:00
parent e08002d72d
commit 17927239a7
19 changed files with 43 additions and 103 deletions

View File

@ -1,18 +1,22 @@
---
base_dir: data/model base_dir: data/model
batch_size: 64 batch_size: 64
cl_decay_steps: 2000 cl_decay_steps: 2000
data_type: ALL data_type: ALL
dropout: 0 dropout: 0
epoch: 100 epoch: 75
epochs: 100 epochs: 100
filter_type: dual_random_walk filter_type: dual_random_walk
global_step: 35451 global_step: !!python/object/apply:numpy.core.multiarray.scalar
- !!python/object/apply:numpy.dtype
args: [i8, 0, 1]
state: !!python/tuple [3, <, null, null, null, -1, -1, 0]
- !!binary |
NGgAAAAAAAA=
graph_pkl_filename: data/sensor_graph/adj_mx.pkl graph_pkl_filename: data/sensor_graph/adj_mx.pkl
horizon: 12 horizon: 12
l1_decay: 0 l1_decay: 0
learning_rate: 0.01 learning_rate: 0.01
log_dir: data/model/dcrnn_DR_2_h_12_64-64_lr_0.01_bs_64_d_0.00_sl_12_MAE_1207002222/ log_dir: data/model/dcrnn_DR_2_h_12_64-64_lr_0.01_bs_64_d_0.00_sl_12_MAE_0606021843/
loss_func: MAE loss_func: MAE
lr_decay: 0.1 lr_decay: 0.1
lr_decay_epoch: 20 lr_decay_epoch: 20
@ -20,9 +24,10 @@ lr_decay_interval: 10
max_diffusion_step: 2 max_diffusion_step: 2
max_grad_norm: 5 max_grad_norm: 5
min_learning_rate: 2.0e-06 min_learning_rate: 2.0e-06
model_filename: data/model/dcrnn_DR_2_h_12_64-64_lr_0.01_bs_64_d_0.00_sl_12_MAE_1207002222/models-1.6253-35451 model_filename: data/model/dcrnn_DR_2_h_12_64-64_lr_0.01_bs_64_d_0.00_sl_12_MAE_0606021843/models-2.8476-26676
null_val: 0 null_val: 0
num_rnn_layers: 2 num_rnn_layers: 2
output_dim: 1
patience: 50 patience: 50
rnn_units: 64 rnn_units: 64
seq_len: 12 seq_len: 12

View File

@ -1,66 +0,0 @@
model_checkpoint_path: "models-1.6253-35451"
all_model_checkpoint_paths: "models-2.9323-351"
all_model_checkpoint_paths: "models-2.2916-702"
all_model_checkpoint_paths: "models-2.1618-1404"
all_model_checkpoint_paths: "models-2.1094-1755"
all_model_checkpoint_paths: "models-2.0356-2106"
all_model_checkpoint_paths: "models-2.0139-2808"
all_model_checkpoint_paths: "models-1.9127-3159"
all_model_checkpoint_paths: "models-1.8968-4914"
all_model_checkpoint_paths: "models-1.8671-5265"
all_model_checkpoint_paths: "models-1.8386-7371"
all_model_checkpoint_paths: "models-1.7334-7722"
all_model_checkpoint_paths: "models-1.7301-8073"
all_model_checkpoint_paths: "models-1.7291-8424"
all_model_checkpoint_paths: "models-1.7214-8775"
all_model_checkpoint_paths: "models-1.7164-9477"
all_model_checkpoint_paths: "models-1.7163-10530"
all_model_checkpoint_paths: "models-1.6611-11232"
all_model_checkpoint_paths: "models-1.6586-11583"
all_model_checkpoint_paths: "models-1.6576-11934"
all_model_checkpoint_paths: "models-1.6554-12636"
all_model_checkpoint_paths: "models-1.6552-13338"
all_model_checkpoint_paths: "models-1.6540-13689"
all_model_checkpoint_paths: "models-1.6526-14391"
all_model_checkpoint_paths: "models-1.6425-14742"
all_model_checkpoint_paths: "models-1.6415-15444"
all_model_checkpoint_paths: "models-1.6385-15795"
all_model_checkpoint_paths: "models-1.6377-16497"
all_model_checkpoint_paths: "models-1.6358-16848"
all_model_checkpoint_paths: "models-1.6358-17901"
all_model_checkpoint_paths: "models-1.6284-18252"
all_model_checkpoint_paths: "models-1.6284-18603"
all_model_checkpoint_paths: "models-1.6282-18954"
all_model_checkpoint_paths: "models-1.6281-19305"
all_model_checkpoint_paths: "models-1.6275-19656"
all_model_checkpoint_paths: "models-1.6273-20007"
all_model_checkpoint_paths: "models-1.6273-20709"
all_model_checkpoint_paths: "models-1.6271-21060"
all_model_checkpoint_paths: "models-1.6266-21411"
all_model_checkpoint_paths: "models-1.6265-21762"
all_model_checkpoint_paths: "models-1.6263-23166"
all_model_checkpoint_paths: "models-1.6262-24570"
all_model_checkpoint_paths: "models-1.6261-24921"
all_model_checkpoint_paths: "models-1.6259-25974"
all_model_checkpoint_paths: "models-1.6259-27378"
all_model_checkpoint_paths: "models-1.6259-27729"
all_model_checkpoint_paths: "models-1.6259-28080"
all_model_checkpoint_paths: "models-1.6259-28431"
all_model_checkpoint_paths: "models-1.6258-28782"
all_model_checkpoint_paths: "models-1.6258-29484"
all_model_checkpoint_paths: "models-1.6258-29835"
all_model_checkpoint_paths: "models-1.6257-30186"
all_model_checkpoint_paths: "models-1.6257-30537"
all_model_checkpoint_paths: "models-1.6257-30888"
all_model_checkpoint_paths: "models-1.6256-31239"
all_model_checkpoint_paths: "models-1.6256-31941"
all_model_checkpoint_paths: "models-1.6255-32292"
all_model_checkpoint_paths: "models-1.6255-32643"
all_model_checkpoint_paths: "models-1.6255-32994"
all_model_checkpoint_paths: "models-1.6254-33345"
all_model_checkpoint_paths: "models-1.6254-33696"
all_model_checkpoint_paths: "models-1.6254-34047"
all_model_checkpoint_paths: "models-1.6253-34398"
all_model_checkpoint_paths: "models-1.6253-34749"
all_model_checkpoint_paths: "models-1.6253-35100"
all_model_checkpoint_paths: "models-1.6253-35451"

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -14,11 +14,12 @@ flags.DEFINE_bool('use_cpu_only', False, 'Whether to run tensorflow on cpu.')
def run_dcrnn(traffic_reading_df): def run_dcrnn(traffic_reading_df):
run_id = 'dcrnn_DR_2_h_12_64-64_lr_0.01_bs_64_d_0.00_sl_12_MAE_1207002222' # run_id = 'dcrnn_DR_2_h_12_64-64_lr_0.01_bs_64_d_0.00_sl_12_MAE_1207002222'
run_id = 'dcrnn_DR_2_h_12_64-64_lr_0.01_bs_64_d_0.00_sl_12_MAE_0606021843'
log_dir = os.path.join('data/model', run_id) log_dir = os.path.join('data/model', run_id)
config_filename = 'config_100.yaml' config_filename = 'config_75.yaml'
graph_pkl_filename = 'data/sensor_graph/adj_mx.pkl' graph_pkl_filename = 'data/sensor_graph/adj_mx.pkl'
with open(os.path.join(log_dir, config_filename)) as f: with open(os.path.join(log_dir, config_filename)) as f:
config = yaml.load(f) config = yaml.load(f)