diff --git a/data/model/dcrnn_config.yaml b/data/model/dcrnn_config.yaml index 27b2bf0..54a1e51 100644 --- a/data/model/dcrnn_config.yaml +++ b/data/model/dcrnn_config.yaml @@ -1,5 +1,6 @@ --- base_dir: data/model +log_level: INFO data: batch_size: 64 dataset_dir: data/METR-LA diff --git a/data/model/dcrnn_config_u16_lap.yaml b/data/model/dcrnn_config_u16_lap.yaml deleted file mode 100644 index d473341..0000000 --- a/data/model/dcrnn_config_u16_lap.yaml +++ /dev/null @@ -1,36 +0,0 @@ ---- -base_dir: data/model -data: - batch_size: 64 - dataset_dir: data/METR-LA - test_batch_size: 64 - val_batch_size: 64 - graph_pkl_filename: data/sensor_graph/adj_mx.pkl - -model: - cl_decay_steps: 2000 - filter_type: laplacian - horizon: 12 - input_dim: 2 - l1_decay: 0 - max_diffusion_step: 2 - max_grad_norm: 5 - num_nodes: 207 - num_rnn_layers: 2 - output_dim: 1 - rnn_units: 16 - seq_len: 12 - use_curriculum_learning: true - -train: - base_lr: 0.01 - dropout: 0 - epoch: 0 - epochs: 100 - global_step: 0 - lr_decay_ratio: 0.1 - steps: [20, 30, 40, 50] - max_to_keep: 100 - min_learning_rate: 2.0e-06 - patience: 50 - test_every_n_epochs: 10 \ No newline at end of file diff --git a/data/model/dcrnn_test_config.yaml b/data/model/dcrnn_test_config.yaml index 60b1d5a..fce08bc 100644 --- a/data/model/dcrnn_test_config.yaml +++ b/data/model/dcrnn_test_config.yaml @@ -1,5 +1,6 @@ --- base_dir: data/model +log_level: INFO data: batch_size: 64 dataset_dir: data/METR-LA @@ -14,11 +15,10 @@ model: input_dim: 2 l1_decay: 0 max_diffusion_step: 2 - max_grad_norm: 5 num_nodes: 207 num_rnn_layers: 2 output_dim: 1 - rnn_units: 64 + rnn_units: 16 seq_len: 12 use_curriculum_learning: true @@ -27,10 +27,13 @@ train: dropout: 0 epoch: 0 epochs: 100 + epsilon: 1.0e-3 global_step: 0 lr_decay_ratio: 0.1 - steps: [20, 30, 40, 50] + max_grad_norm: 5 max_to_keep: 100 min_learning_rate: 2.0e-06 + optimizer: adam patience: 50 + steps: [20, 30, 40, 50] test_every_n_epochs: 10 \ No newline at end of file diff --git a/data/model/pretrained/config.yaml b/data/model/pretrained/config.yaml index 621ba8d..5292f70 100644 --- a/data/model/pretrained/config.yaml +++ b/data/model/pretrained/config.yaml @@ -1,4 +1,5 @@ base_dir: data/model +log_level: INFO data: batch_size: 64 dataset_dir: data/METR-LA diff --git a/lib/utils.py b/lib/utils.py index 4933cac..82dd538 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -145,9 +145,9 @@ def config_logging(log_dir, log_filename='info.log', level=logging.INFO): logging.basicConfig(handlers=[file_handler, console_handler], level=level) -def get_logger(log_dir, name, log_filename='info.log'): +def get_logger(log_dir, name, log_filename='info.log', level=logging.INFO): logger = logging.getLogger(name) - logger.setLevel(logging.DEBUG) + logger.setLevel(level) # Add file handler and stdout handler formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') file_handler = logging.FileHandler(os.path.join(log_dir, log_filename)) @@ -175,6 +175,25 @@ def get_total_trainable_parameter_size(): return total_parameters +def load_dataset(dataset_dir, batch_size, test_batch_size=None, **kwargs): + data = {} + for category in ['train', 'val', 'test']: + cat_data = np.load(os.path.join(dataset_dir, category + '.npz')) + data['x_' + category] = cat_data['x'] + data['y_' + category] = cat_data['y'] + scaler = StandardScaler(mean=data['x_train'][..., 0].mean(), std=data['x_train'][..., 0].std()) + # Data format + for category in ['train', 'val', 'test']: + data['x_' + category][..., 0] = scaler.transform(data['x_' + category][..., 0]) + data['y_' + category][..., 0] = scaler.transform(data['y_' + category][..., 0]) + data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True) + data['val_loader'] = DataLoader(data['x_val'], data['y_val'], test_batch_size, shuffle=False) + data['test_loader'] = DataLoader(data['x_test'], data['y_test'], test_batch_size, shuffle=False) + data['scaler'] = scaler + + return data + + def load_graph_data(pkl_filename): sensor_ids, sensor_id_to_ind, adj_mx = load_pickle(pkl_filename) return sensor_ids, sensor_id_to_ind, adj_mx