Add pretrained model on PEMS-BAY.
This commit is contained in:
parent
ad36deb794
commit
763eb7af69
|
|
@ -39,9 +39,13 @@ The generated train/val/test dataset will be saved at `data/{METR-LA,PEMS-BAY}/{
|
|||
## Run the Pre-trained Model on METR-LA
|
||||
|
||||
```bash
|
||||
python run_demo.py
|
||||
# METR-LA
|
||||
python run_demo.py --config_filename=data/model/pretrained/METR-LA/config.yaml
|
||||
|
||||
# PEMS-BAY
|
||||
python run_demo.py --config_filename=data/model/pretrained/PEMS-BAY/config.yaml
|
||||
```
|
||||
The generated prediction of DCRNN of METR-LA is in `data/results/dcrnn_predictions_[1-12].h5`.
|
||||
The generated prediction of DCRNN is in `data/results/dcrnn_predictions`.
|
||||
|
||||
|
||||
## Model Training
|
||||
|
|
|
|||
|
|
@ -25,12 +25,12 @@ train:
|
|||
epochs: 100
|
||||
epsilon: 0.001
|
||||
global_step: 24375
|
||||
log_dir: data/model/pretrained/
|
||||
log_dir: data/model/pretrained/METR-LA
|
||||
lr_decay_ratio: 0.1
|
||||
max_grad_norm: 5
|
||||
max_to_keep: 100
|
||||
min_learning_rate: 2.0e-06
|
||||
model_filename: data/model/pretrained/models-2.7422-24375
|
||||
model_filename: data/model/pretrained/METR-LA/models-2.7422-24375
|
||||
optimizer: adam
|
||||
patience: 50
|
||||
steps:
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
base_dir: data/model
|
||||
data:
|
||||
batch_size: 64
|
||||
dataset_dir: data/PEMS-BAY
|
||||
graph_pkl_filename: data/sensor_graph/adj_mx_bay.pkl
|
||||
test_batch_size: 64
|
||||
val_batch_size: 64
|
||||
log_level: INFO
|
||||
model:
|
||||
cl_decay_steps: 2000
|
||||
filter_type: dual_random_walk
|
||||
horizon: 12
|
||||
input_dim: 2
|
||||
l1_decay: 0
|
||||
max_diffusion_step: 2
|
||||
num_nodes: 325
|
||||
num_rnn_layers: 2
|
||||
output_dim: 1
|
||||
rnn_units: 64
|
||||
seq_len: 12
|
||||
use_curriculum_learning: true
|
||||
train:
|
||||
base_lr: 0.01
|
||||
dropout: 0
|
||||
epoch: 53
|
||||
epochs: 100
|
||||
epsilon: 0.001
|
||||
global_step: 30780
|
||||
log_dir: data/model/pretrained/PEMS-BAY/
|
||||
lr_decay_ratio: 0.1
|
||||
max_grad_norm: 5
|
||||
max_to_keep: 100
|
||||
min_learning_rate: 2.0e-06
|
||||
model_filename: data/model/pretrained/PEMS-BAY/models-1.6139-30780
|
||||
optimizer: adam
|
||||
patience: 50
|
||||
steps:
|
||||
- 20
|
||||
- 30
|
||||
- 40
|
||||
- 50
|
||||
test_every_n_epochs: 10
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -10,13 +10,13 @@ from model.dcrnn_supervisor import DCRNNSupervisor
|
|||
|
||||
|
||||
def run_dcrnn(args):
|
||||
graph_pkl_filename = 'data/sensor_graph/adj_mx.pkl'
|
||||
with open(args.config_filename) as f:
|
||||
config = yaml.load(f)
|
||||
tf_config = tf.ConfigProto()
|
||||
if args.use_cpu_only:
|
||||
tf_config = tf.ConfigProto(device_count={'GPU': 0})
|
||||
tf_config.gpu_options.allow_growth = True
|
||||
graph_pkl_filename = config['data']['graph_pkl_filename']
|
||||
_, _, adj_mx = load_graph_data(graph_pkl_filename)
|
||||
with tf.Session(config=tf_config) as sess:
|
||||
supervisor = DCRNNSupervisor(adj_mx=adj_mx, **config)
|
||||
|
|
@ -30,7 +30,7 @@ if __name__ == '__main__':
|
|||
sys.path.append(os.getcwd())
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--use_cpu_only', default=False, type=str, help='Whether to run tensorflow on cpu.')
|
||||
parser.add_argument('--config_filename', default='data/model/pretrained/config.yaml', type=str,
|
||||
parser.add_argument('--config_filename', default='data/model/pretrained/METR-LA/config.yaml', type=str,
|
||||
help='Config file for pretrained model.')
|
||||
parser.add_argument('--output_filename', default='data/dcrnn_predictions.npz')
|
||||
args = parser.parse_args()
|
||||
|
|
|
|||
Loading…
Reference in New Issue