# STDEN/model/stden_supervisor.py

import os
import time
from random import SystemRandom
import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from lib import utils
from model.stden_model import STDENModel
from lib.metrics import masked_mae_loss, masked_mape_loss, masked_mse_loss, masked_rmse_loss
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class STDENSupervisor:
    def __init__(self, adj_mx, **kwargs):
        self._kwargs = kwargs
        self._data_kwargs = kwargs.get('data')
        self._model_kwargs = kwargs.get('model')
        self._train_kwargs = kwargs.get('train')
        self.max_grad_norm = self._train_kwargs.get('max_grad_norm', 1.)

        # logging
        self._log_dir = utils.get_log_dir(kwargs)
        self._writer = SummaryWriter('runs/' + self._log_dir)
        log_level = self._kwargs.get('log_level', 'INFO')
        self._logger = utils.get_logger(self._log_dir, __name__, 'info.log', level=log_level)

        # data set
        self._data = utils.load_dataset(**self._data_kwargs)
        self.standard_scaler = self._data['scaler']
        self._logger.info('Scaler mean: {:.6f}, std {:.6f}.'.format(
            self.standard_scaler.mean, self.standard_scaler.std))

        self.num_edges = int((adj_mx > 0.).sum())
        self.input_dim = int(self._model_kwargs.get('input_dim', 1))
        self.seq_len = int(self._model_kwargs.get('seq_len'))  # for the encoder
        self.output_dim = int(self._model_kwargs.get('output_dim', 1))
        self.use_curriculum_learning = bool(
            self._model_kwargs.get('use_curriculum_learning', False))
        self.horizon = int(self._model_kwargs.get('horizon', 1))  # for the decoder

        # setup model
        stden_model = STDENModel(adj_mx, self._logger, **self._model_kwargs)
        self.stden_model = stden_model.cuda() if torch.cuda.is_available() else stden_model
        self._logger.info("Model created")

        self.experimentID = self._train_kwargs.get('load', 0)
        if self.experimentID == 0:
            # make a new experiment ID
            self.experimentID = int(SystemRandom().random() * 100000)
        self.ckpt_path = os.path.join("ckpt/", "experiment_" + str(self.experimentID))

        self._epoch_num = self._train_kwargs.get('epoch', 0)
        if self._epoch_num > 0:
            self._logger.info('Loading model...')
            self.load_model()
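
    # The constructor expects a nested config dict. An illustrative
    # (hypothetical) layout, based only on the keys this file actually reads:
    #   data:  {dataset_dir: ..., batch_size: ...}
    #   model: {seq_len: 12, horizon: 12, input_dim: 1, output_dim: 1,
    #           use_curriculum_learning: false, filter_type: ...,
    #           odeint_atol: 1.e-5, odeint_rtol: 1.e-5, nfe: false}
    #   train: {base_lr: 0.01, steps: [...], epochs: 100, epoch: 0, load: 0,
    #           max_grad_norm: 1., patience: 50, test_every_n_epochs: 10}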

    def save_model(self, epoch):
        model_dir = self.ckpt_path
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        # bundle the full config with the weights so a checkpoint is self-describing
        config = dict(self._kwargs)
        config['model_state_dict'] = self.stden_model.state_dict()
        config['epoch'] = epoch
        model_path = os.path.join(model_dir, 'epo{}.tar'.format(epoch))
        torch.save(config, model_path)
        self._logger.info("Saved model at epoch {}".format(epoch))
        return model_path

    def load_model(self):
        # run a forward pass first so that any dynamically registered
        # parameters exist before load_state_dict is called
        self._setup_graph()
        model_path = os.path.join(self.ckpt_path, 'epo{}.tar'.format(self._epoch_num))
        assert os.path.exists(model_path), 'Weights at epoch %d not found' % self._epoch_num
        checkpoint = torch.load(model_path, map_location='cpu')
        self.stden_model.load_state_dict(checkpoint['model_state_dict'])
        self._logger.info("Loaded model at epoch {}".format(self._epoch_num))

    def _setup_graph(self):
        with torch.no_grad():
            self.stden_model.eval()
            val_iterator = self._data['val_loader'].get_iterator()
            # a single forward pass on one validation batch is enough
            for _, (x, y) in enumerate(val_iterator):
                x, y = self._prepare_data(x, y)
                self.stden_model(x)
                break

    def train(self, **kwargs):
        self._logger.info('Model mode: train')
        kwargs.update(self._train_kwargs)
        return self._train(**kwargs)

    def _train(self, base_lr,
               steps, patience=50, epochs=100, lr_decay_ratio=0.1, log_every=1, save_model=1,
               test_every_n_epochs=10, epsilon=1e-8, **kwargs):
        # `steps` lists the epochs at which the learning rate is decayed
        # (the MultiStepLR milestones below).
        min_val_loss = float('inf')
        wait = 0
        optimizer = torch.optim.Adam(self.stden_model.parameters(), lr=base_lr, eps=epsilon)
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
                                                            gamma=lr_decay_ratio)

        self._logger.info('Start training ...')

        # this will fail if the model is loaded with a changed batch_size
        num_batches = self._data['train_loader'].num_batch
        self._logger.info("num_batches: {}".format(num_batches))

        batches_seen = num_batches * self._epoch_num
        # per-iteration NFE (number of function evaluations) bookkeeping
        c = []
        res, keys = [], []
        for epoch_num in range(self._epoch_num, epochs):
            self.stden_model.train()
            train_iterator = self._data['train_loader'].get_iterator()
            losses = []
            start_time = time.time()
            c.clear()  # reset NFE stats for this epoch

            for i, (x, y) in enumerate(train_iterator):
                if i >= num_batches:
                    break
                optimizer.zero_grad()
                x, y = self._prepare_data(x, y)
                output, fe = self.stden_model(x, y, batches_seen)

                if batches_seen == 0:
                    # workaround to accommodate dynamically registered parameters:
                    # rebuild the optimizer after the first forward pass has
                    # created all of them
                    optimizer = torch.optim.Adam(self.stden_model.parameters(), lr=base_lr, eps=epsilon)

                loss = self._compute_loss(y, output)
                self._logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))
                c.append([*fe, loss.item()])
                self._logger.debug(loss.item())
                losses.append(loss.item())

                batches_seen += 1  # global step in tensorboard
                loss.backward()

                # gradient clipping
                torch.nn.utils.clip_grad_norm_(self.stden_model.parameters(), self.max_grad_norm)
                optimizer.step()

                del x, y, output, loss    # drop references so the tensors become garbage
                torch.cuda.empty_cache()  # return the freed blocks to the GPU

            # collect this epoch's NFE statistics
            res.append(pd.DataFrame(c, columns=['nfe', 'time', 'err']))
            keys.append(epoch_num)

            self._logger.info("epoch complete")
            lr_scheduler.step()
            self._logger.info("evaluating now!")

            val_loss, _ = self.evaluate(dataset='val', batches_seen=batches_seen)
            end_time = time.time()

            self._writer.add_scalar('training loss',
                                    np.mean(losses),
                                    batches_seen)

            if (epoch_num % log_every) == log_every - 1:
                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}, ' \
                          '{:.1f}s'.format(epoch_num, epochs, batches_seen,
                                           np.mean(losses), val_loss, lr_scheduler.get_last_lr()[0],
                                           (end_time - start_time))
                self._logger.info(message)

            if (epoch_num % test_every_n_epochs) == test_every_n_epochs - 1:
                test_loss, _ = self.evaluate(dataset='test', batches_seen=batches_seen)
                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \
                          '{:.1f}s'.format(epoch_num, epochs, batches_seen,
                                           np.mean(losses), test_loss, lr_scheduler.get_last_lr()[0],
                                           (end_time - start_time))
                self._logger.info(message)

            if val_loss < min_val_loss:
                wait = 0
                if save_model:
                    model_file_name = self.save_model(epoch_num)
                    self._logger.info(
                        'Val loss decrease from {:.4f} to {:.4f}, '
                        'saving to {}'.format(min_val_loss, val_loss, model_file_name))
                min_val_loss = val_loss
            else:
                wait += 1
                if wait == patience:
                    self._logger.warning('Early stopping at epoch: %d' % epoch_num)
                    break

        if bool(self._model_kwargs.get('nfe', False)):
            res = pd.concat(res, keys=keys)
            res.index.names = ['epoch', 'iter']
            filter_type = self._model_kwargs.get('filter_type', 'unknown')
            atol = float(self._model_kwargs.get('odeint_atol', 1e-5))
            rtol = float(self._model_kwargs.get('odeint_rtol', 1e-5))
            nfe_file = os.path.join(
                self._data_kwargs.get('dataset_dir', 'data'),
                'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol * 1e5), int(rtol * 1e5)))
            res.to_pickle(nfe_file)
            # res.to_csv(nfe_file)
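            # To inspect the saved NFE statistics later, a hypothetical
            # analysis snippet (not called anywhere in this module):
            #   res = pd.read_pickle(nfe_file)
            #   print(res.groupby(level='epoch').mean())  # mean nfe/time/err per epoch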

    def _prepare_data(self, x, y):
        x, y = self._get_x_y(x, y)
        x, y = self._get_x_y_in_correct_dims(x, y)
        return x.to(device), y.to(device)

    def _get_x_y(self, x, y):
        """
        :param x: shape (batch_size, seq_len, num_edges, input_dim)
        :param y: shape (batch_size, horizon, num_edges, input_dim)
        :return: x: shape (seq_len, batch_size, num_edges, input_dim)
                 y: shape (horizon, batch_size, num_edges, input_dim)
        """
        x = torch.from_numpy(x).float()
        y = torch.from_numpy(y).float()
        self._logger.debug("X: {}".format(x.size()))
        self._logger.debug("y: {}".format(y.size()))
        # move the time axis to the front
        x = x.permute(1, 0, 2, 3)
        y = y.permute(1, 0, 2, 3)
        return x, y

    def _get_x_y_in_correct_dims(self, x, y):
        """
        :param x: shape (seq_len, batch_size, num_edges, input_dim)
        :param y: shape (horizon, batch_size, num_edges, input_dim)
        :return: x: shape (seq_len, batch_size, num_edges * input_dim)
                 y: shape (horizon, batch_size, num_edges * output_dim)
        """
        batch_size = x.size(1)
        self._logger.debug("size of x {}".format(x.size()))
        x = x.view(self.seq_len, batch_size, self.num_edges * self.input_dim)
        y = y[..., :self.output_dim].view(self.horizon, batch_size,
                                          self.num_edges * self.output_dim)
        return x, y
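
    # Illustrative shapes (hypothetical values): with seq_len=12, horizon=12,
    # batch_size=64, num_edges=300 and input_dim=output_dim=1, x is reshaped
    # from (12, 64, 300, 1) to (12, 64, 300), and y likewise.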

    def _compute_loss(self, y_true, y_predicted):
        y_true = self.standard_scaler.inverse_transform(y_true)
        y_predicted = self.standard_scaler.inverse_transform(y_predicted)
        return masked_mae_loss(y_predicted, y_true)
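
    # For reference, a typical masked MAE in DCRNN-style codebases looks like
    # the sketch below; the authoritative definition lives in lib/metrics.py
    # and may differ:
    #   def masked_mae_loss(y_pred, y_true):
    #       mask = (y_true != 0).float()
    #       mask /= mask.mean()                      # re-normalize so the mean stays unbiased
    #       loss = torch.abs(y_pred - y_true) * mask
    #       loss[loss != loss] = 0                   # zero out NaNs from 0/0
    #       return loss.mean()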

    def _compute_loss_eval(self, y_true, y_predicted):
        y_true = self.standard_scaler.inverse_transform(y_true)
        y_predicted = self.standard_scaler.inverse_transform(y_predicted)
        return (masked_mae_loss(y_predicted, y_true).item(),
                masked_mape_loss(y_predicted, y_true).item(),
                masked_rmse_loss(y_predicted, y_true).item())

    def evaluate(self, dataset='val', batches_seen=0, save=False):
        """
        Computes MAE/MAPE/RMSE on `dataset` and, if `save` is True, also
        returns the inverse-transformed predictions and ground truth.
        :return: mean MAE, and a prediction/truth dict (None unless `save`)
        """
        with torch.no_grad():
            self.stden_model.eval()
            val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
            mae_losses = []
            mape_losses = []
            rmse_losses = []
            y_dict = None
            if save:
                y_truths = []
                y_preds = []

            for _, (x, y) in enumerate(val_iterator):
                x, y = self._prepare_data(x, y)
                output, fe = self.stden_model(x)
                mae, mape, rmse = self._compute_loss_eval(y, output)
                mae_losses.append(mae)
                mape_losses.append(mape)
                rmse_losses.append(rmse)
                if save:
                    y_truths.append(y.cpu())
                    y_preds.append(output.cpu())

            mean_loss = {
                'mae': np.mean(mae_losses),
                'mape': np.mean(mape_losses),
                'rmse': np.mean(rmse_losses)
            }
            self._logger.info('Evaluation: - mae - {:.4f} - mape - {:.4f} - rmse - {:.4f}'.format(
                mean_loss['mae'], mean_loss['mape'], mean_loss['rmse']))
            self._writer.add_scalar('{} loss'.format(dataset), mean_loss['mae'], batches_seen)

            if save:
                y_preds = np.concatenate(y_preds, axis=1)
                y_truths = np.concatenate(y_truths, axis=1)  # concatenate on the batch dimension
                y_truths_scaled = []
                y_preds_scaled = []
                for t in range(y_preds.shape[0]):
                    y_truth = self.standard_scaler.inverse_transform(y_truths[t])
                    y_pred = self.standard_scaler.inverse_transform(y_preds[t])
                    y_truths_scaled.append(y_truth)
                    y_preds_scaled.append(y_pred)
                y_preds_scaled = np.stack(y_preds_scaled)
                y_truths_scaled = np.stack(y_truths_scaled)
                y_dict = {'prediction': y_preds_scaled, 'truth': y_truths_scaled}
                # save_dir = self._data_kwargs.get('dataset_dir', 'data')
                # save_path = os.path.join(save_dir, 'pred.npz')
                # np.savez(save_path, prediction=y_preds_scaled, truth=y_truths_scaled)

            return mean_loss['mae'], y_dict

    def eval_more(self, dataset='val', save=False, seq_len=(3, 6, 9, 12), extract_latent=False):
        """
        Computes MAE/MAPE/RMSE at each horizon step in `seq_len`, and saves the
        predictions (and, optionally, the latent features) if requested.
        """
        self._logger.info('Model mode: Evaluation')
        with torch.no_grad():
            self.stden_model.eval()
            val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
            mae_losses = []
            mape_losses = []
            rmse_losses = []
            if save:
                y_truths = []
                y_preds = []
            if extract_latent:
                latents = []
            # per-iteration NFE bookkeeping
            c = []

            for _, (x, y) in enumerate(val_iterator):
                x, y = self._prepare_data(x, y)
                output, fe = self.stden_model(x)
                mae, mape, rmse = [], [], []
                for seq in seq_len:
                    _mae, _mape, _rmse = self._compute_loss_eval(y[seq - 1], output[seq - 1])
                    mae.append(_mae)
                    mape.append(_mape)
                    rmse.append(_rmse)
                mae_losses.append(mae)
                mape_losses.append(mape)
                rmse_losses.append(rmse)
                c.append([*fe, np.mean(mae)])
                if save:
                    y_truths.append(y.cpu())
                    y_preds.append(output.cpu())
                if extract_latent:
                    latents.append(self.stden_model.latent_feat.cpu())

            mean_loss = {
                'mae': np.mean(mae_losses, axis=0),
                'mape': np.mean(mape_losses, axis=0),
                'rmse': np.mean(rmse_losses, axis=0)
            }
            for i, seq in enumerate(seq_len):
                self._logger.info('Evaluation seq {}: - mae - {:.4f} - mape - {:.4f} - rmse - {:.4f}'.format(
                    seq, mean_loss['mae'][i], mean_loss['mape'][i], mean_loss['rmse'][i]))

            if save:
                # shape (horizon, num_samples, feat_dim)
                y_preds = np.concatenate(y_preds, axis=1)
                y_truths = np.concatenate(y_truths, axis=1)  # concatenate on the batch dimension
                y_preds_scaled = self.standard_scaler.inverse_transform(y_preds)
                y_truths_scaled = self.standard_scaler.inverse_transform(y_truths)
                save_dir = self._data_kwargs.get('dataset_dir', 'data')
                save_path = os.path.join(save_dir, 'pred_{}_{}.npz'.format(self.experimentID, self._epoch_num))
                np.savez_compressed(save_path, prediction=y_preds_scaled, truth=y_truths_scaled)

            if extract_latent:
                # concatenate on the batch dimension;
                # shape of latents: (horizon, num_samples, self.num_edges * self.output_dim)
                latents = np.concatenate(latents, axis=1)
                save_dir = self._data_kwargs.get('dataset_dir', 'data')
                filter_type = self._model_kwargs.get('filter_type', 'unknown')
                save_path = os.path.join(save_dir, '{}_latent_{}_{}.npz'.format(
                    filter_type, self.experimentID, self._epoch_num))
                np.savez_compressed(save_path, latent=latents)

            if bool(self._model_kwargs.get('nfe', False)):
                res = pd.DataFrame(c, columns=['nfe', 'time', 'err'])
                res.index.name = 'iter'
                filter_type = self._model_kwargs.get('filter_type', 'unknown')
                atol = float(self._model_kwargs.get('odeint_atol', 1e-5))
                rtol = float(self._model_kwargs.get('odeint_rtol', 1e-5))
                nfe_file = os.path.join(
                    self._data_kwargs.get('dataset_dir', 'data'),
                    'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol * 1e5), int(rtol * 1e5)))
                res.to_pickle(nfe_file)
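

# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical; not part of the original module). It
# assumes a YAML config whose top-level 'data'/'model'/'train' sections match
# what STDENSupervisor.__init__ reads, and an adjacency matrix stored as a
# NumPy array; the repo's real entry script may load these differently.
if __name__ == '__main__':
    import argparse
    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='data/model/stden.yaml',
                        help='hypothetical path to a YAML config file')
    parser.add_argument('--adj', default='data/adj_mx.npy',
                        help='hypothetical path to a saved adjacency matrix')
    args = parser.parse_args()

    with open(args.config) as f:
        config = yaml.safe_load(f)  # expects 'data', 'model' and 'train' keys
    adj_mx = np.load(args.adj)      # shape (num_nodes, num_nodes)

    supervisor = STDENSupervisor(adj_mx, **config)
    supervisor.train()                               # kwargs come from config['train']
    supervisor.eval_more(dataset='test', save=True)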