Add log level support and load_dataset method.
This commit is contained in:
parent
88d9fc86d1
commit
bbc06b6c0c
|
|
@ -1,5 +1,6 @@
|
|||
---
|
||||
base_dir: data/model
|
||||
log_level: INFO
|
||||
data:
|
||||
batch_size: 64
|
||||
dataset_dir: data/METR-LA
|
||||
|
|
|
|||
|
|
@ -1,36 +0,0 @@
|
|||
---
|
||||
base_dir: data/model
|
||||
data:
|
||||
batch_size: 64
|
||||
dataset_dir: data/METR-LA
|
||||
test_batch_size: 64
|
||||
val_batch_size: 64
|
||||
graph_pkl_filename: data/sensor_graph/adj_mx.pkl
|
||||
|
||||
model:
|
||||
cl_decay_steps: 2000
|
||||
filter_type: laplacian
|
||||
horizon: 12
|
||||
input_dim: 2
|
||||
l1_decay: 0
|
||||
max_diffusion_step: 2
|
||||
max_grad_norm: 5
|
||||
num_nodes: 207
|
||||
num_rnn_layers: 2
|
||||
output_dim: 1
|
||||
rnn_units: 16
|
||||
seq_len: 12
|
||||
use_curriculum_learning: true
|
||||
|
||||
train:
|
||||
base_lr: 0.01
|
||||
dropout: 0
|
||||
epoch: 0
|
||||
epochs: 100
|
||||
global_step: 0
|
||||
lr_decay_ratio: 0.1
|
||||
steps: [20, 30, 40, 50]
|
||||
max_to_keep: 100
|
||||
min_learning_rate: 2.0e-06
|
||||
patience: 50
|
||||
test_every_n_epochs: 10
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
---
|
||||
base_dir: data/model
|
||||
log_level: INFO
|
||||
data:
|
||||
batch_size: 64
|
||||
dataset_dir: data/METR-LA
|
||||
|
|
@ -14,11 +15,10 @@ model:
|
|||
input_dim: 2
|
||||
l1_decay: 0
|
||||
max_diffusion_step: 2
|
||||
max_grad_norm: 5
|
||||
num_nodes: 207
|
||||
num_rnn_layers: 2
|
||||
output_dim: 1
|
||||
rnn_units: 64
|
||||
rnn_units: 16
|
||||
seq_len: 12
|
||||
use_curriculum_learning: true
|
||||
|
||||
|
|
@ -27,10 +27,13 @@ train:
|
|||
dropout: 0
|
||||
epoch: 0
|
||||
epochs: 100
|
||||
epsilon: 1.0e-3
|
||||
global_step: 0
|
||||
lr_decay_ratio: 0.1
|
||||
steps: [20, 30, 40, 50]
|
||||
max_grad_norm: 5
|
||||
max_to_keep: 100
|
||||
min_learning_rate: 2.0e-06
|
||||
optimizer: adam
|
||||
patience: 50
|
||||
steps: [20, 30, 40, 50]
|
||||
test_every_n_epochs: 10
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
base_dir: data/model
|
||||
log_level: INFO
|
||||
data:
|
||||
batch_size: 64
|
||||
dataset_dir: data/METR-LA
|
||||
|
|
|
|||
23
lib/utils.py
23
lib/utils.py
|
|
@ -145,9 +145,9 @@ def config_logging(log_dir, log_filename='info.log', level=logging.INFO):
|
|||
logging.basicConfig(handlers=[file_handler, console_handler], level=level)
|
||||
|
||||
|
||||
def get_logger(log_dir, name, log_filename='info.log'):
|
||||
def get_logger(log_dir, name, log_filename='info.log', level=logging.INFO):
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.setLevel(level)
|
||||
# Add file handler and stdout handler
|
||||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
|
||||
|
|
@ -175,6 +175,25 @@ def get_total_trainable_parameter_size():
|
|||
return total_parameters
|
||||
|
||||
|
||||
def load_dataset(dataset_dir, batch_size, test_batch_size=None, **kwargs):
|
||||
data = {}
|
||||
for category in ['train', 'val', 'test']:
|
||||
cat_data = np.load(os.path.join(dataset_dir, category + '.npz'))
|
||||
data['x_' + category] = cat_data['x']
|
||||
data['y_' + category] = cat_data['y']
|
||||
scaler = StandardScaler(mean=data['x_train'][..., 0].mean(), std=data['x_train'][..., 0].std())
|
||||
# Data format
|
||||
for category in ['train', 'val', 'test']:
|
||||
data['x_' + category][..., 0] = scaler.transform(data['x_' + category][..., 0])
|
||||
data['y_' + category][..., 0] = scaler.transform(data['y_' + category][..., 0])
|
||||
data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True)
|
||||
data['val_loader'] = DataLoader(data['x_val'], data['y_val'], test_batch_size, shuffle=False)
|
||||
data['test_loader'] = DataLoader(data['x_test'], data['y_test'], test_batch_size, shuffle=False)
|
||||
data['scaler'] = scaler
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def load_graph_data(pkl_filename):
|
||||
sensor_ids, sensor_id_to_ind, adj_mx = load_pickle(pkl_filename)
|
||||
return sensor_ids, sensor_id_to_ind, adj_mx
|
||||
|
|
|
|||
Loading…
Reference in New Issue