# mirror of https://github.com/czzhangheng/STDEN.git
import os
import time
from random import SystemRandom

import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter

from lib import utils
from model.stden_model import STDENModel
from lib.metrics import masked_mae_loss, masked_mape_loss, masked_mse_loss, masked_rmse_loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class STDENSupervisor:
    """Training / evaluation driver for STDENModel.

    Wraps dataset loading, checkpointing, TensorBoard logging and the
    train / evaluate loops around a single STDENModel instance.
    """

    def __init__(self, adj_mx, **kwargs):
        self._kwargs = kwargs
        self._data_kwargs = kwargs.get('data')
        self._model_kwargs = kwargs.get('model')
        self._train_kwargs = kwargs.get('train')

        self.max_grad_norm = self._train_kwargs.get('max_grad_norm', 1.)

        # logging
        self._log_dir = utils.get_log_dir(kwargs)
        self._writer = SummaryWriter('runs/' + self._log_dir)

        log_level = self._kwargs.get('log_level', 'INFO')
        self._logger = utils.get_logger(self._log_dir, __name__, 'info.log', level=log_level)

        # dataset
        self._data = utils.load_dataset(**self._data_kwargs)
        self.standard_scaler = self._data['scaler']
        self._logger.info('Scaler mean: {:.6f}, std {:.6f}.'.format(self.standard_scaler.mean, self.standard_scaler.std))

        self.num_edges = (adj_mx > 0.).sum()
        self.input_dim = int(self._model_kwargs.get('input_dim', 1))
        self.seq_len = int(self._model_kwargs.get('seq_len'))  # for the encoder
        self.output_dim = int(self._model_kwargs.get('output_dim', 1))
        self.use_curriculum_learning = bool(
            self._model_kwargs.get('use_curriculum_learning', False))
        self.horizon = int(self._model_kwargs.get('horizon', 1))  # for the decoder

        # set up the model
        stden_model = STDENModel(adj_mx, self._logger, **self._model_kwargs)
        self.stden_model = stden_model.cuda() if torch.cuda.is_available() else stden_model
        self._logger.info("Model created")

        self.experimentID = self._train_kwargs.get('load', 0)
        if self.experimentID == 0:
            # make a new experiment ID
            self.experimentID = int(SystemRandom().random() * 100000)
        self.ckpt_path = os.path.join("ckpt/", "experiment_" + str(self.experimentID))

        self._epoch_num = self._train_kwargs.get('epoch', 0)
        if self._epoch_num > 0:
            self._logger.info('Loading model...')
            self.load_model()

    def save_model(self, epoch):
        model_dir = self.ckpt_path
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        config = dict(self._kwargs)
        config['model_state_dict'] = self.stden_model.state_dict()
        config['epoch'] = epoch
        model_path = os.path.join(model_dir, 'epo{}.tar'.format(epoch))
        torch.save(config, model_path)
        self._logger.info("Saved model at {}".format(epoch))
        return model_path

    def load_model(self):
        self._setup_graph()
        model_path = os.path.join(self.ckpt_path, 'epo{}.tar'.format(self._epoch_num))
        assert os.path.exists(model_path), 'Weights at epoch %d not found' % self._epoch_num

        checkpoint = torch.load(model_path, map_location='cpu')
        self.stden_model.load_state_dict(checkpoint['model_state_dict'])
        self._logger.info("Loaded model at {}".format(self._epoch_num))

    def _setup_graph(self):
        with torch.no_grad():
            self.stden_model.eval()

            val_iterator = self._data['val_loader'].get_iterator()

            for _, (x, y) in enumerate(val_iterator):
                x, y = self._prepare_data(x, y)
                output = self.stden_model(x)
                break

    def train(self, **kwargs):
        self._logger.info('Model mode: train')
        kwargs.update(self._train_kwargs)
        return self._train(**kwargs)

    def _train(self, base_lr,
               steps, patience=50, epochs=100, lr_decay_ratio=0.1, log_every=1, save_model=1,
               test_every_n_epochs=10, epsilon=1e-8, **kwargs):
        # `steps` lists the epoch milestones at which the learning rate is decayed
        min_val_loss = float('inf')
        wait = 0
        optimizer = torch.optim.Adam(self.stden_model.parameters(), lr=base_lr, eps=epsilon)

        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
                                                            gamma=lr_decay_ratio)

        self._logger.info('Start training ...')

        # this will fail if the model is loaded with a changed batch_size
        num_batches = self._data['train_loader'].num_batch
        self._logger.info("num_batches: {}".format(num_batches))

        batches_seen = num_batches * self._epoch_num

        # used to record the ODE solver's number of function evaluations (NFE)
        c = []
        res, keys = [], []

        for epoch_num in range(self._epoch_num, epochs):

            self.stden_model.train()

            train_iterator = self._data['train_loader'].get_iterator()
            losses = []

            start_time = time.time()

            c.clear()  # reset the per-epoch NFE records
            for i, (x, y) in enumerate(train_iterator):
                if i >= num_batches:
                    break
                optimizer.zero_grad()

                x, y = self._prepare_data(x, y)

                output, fe = self.stden_model(x, y, batches_seen)

                if batches_seen == 0:
                    # this is a workaround to accommodate dynamically registered parameters
                    optimizer = torch.optim.Adam(self.stden_model.parameters(), lr=base_lr, eps=epsilon)

                loss = self._compute_loss(y, output)
                self._logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))
                c.append([*fe, loss.item()])

                self._logger.debug(loss.item())
                losses.append(loss.item())

                batches_seen += 1  # global step in tensorboard
                loss.backward()

                # gradient clipping
                torch.nn.utils.clip_grad_norm_(self.stden_model.parameters(), self.max_grad_norm)

                optimizer.step()

                # drop references so the tensors can be garbage-collected,
                # then release the cached GPU memory they occupied
                del x, y, output, loss
                torch.cuda.empty_cache()

            # collect this epoch's NFE records
            res.append(pd.DataFrame(c, columns=['nfe', 'time', 'err']))
            keys.append(epoch_num)

            self._logger.info("epoch complete")
            lr_scheduler.step()
            self._logger.info("evaluating now!")

            val_loss, _ = self.evaluate(dataset='val', batches_seen=batches_seen)

            end_time = time.time()

            self._writer.add_scalar('training loss',
                                    np.mean(losses),
                                    batches_seen)

            if (epoch_num % log_every) == log_every - 1:
                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}, ' \
                          '{:.1f}s'.format(epoch_num, epochs, batches_seen,
                                           np.mean(losses), val_loss, lr_scheduler.get_lr()[0],
                                           (end_time - start_time))
                self._logger.info(message)

            if (epoch_num % test_every_n_epochs) == test_every_n_epochs - 1:
                test_loss, _ = self.evaluate(dataset='test', batches_seen=batches_seen)
                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \
                          '{:.1f}s'.format(epoch_num, epochs, batches_seen,
                                           np.mean(losses), test_loss, lr_scheduler.get_lr()[0],
                                           (end_time - start_time))
                self._logger.info(message)

            if val_loss < min_val_loss:
                wait = 0
                if save_model:
                    model_file_name = self.save_model(epoch_num)
                    self._logger.info(
                        'Val loss decrease from {:.4f} to {:.4f}, '
                        'saving to {}'.format(min_val_loss, val_loss, model_file_name))
                min_val_loss = val_loss

            elif val_loss >= min_val_loss:
                wait += 1
                if wait == patience:
                    self._logger.warning('Early stopping at epoch: %d' % epoch_num)
                    break

        if bool(self._model_kwargs.get('nfe', False)):
            res = pd.concat(res, keys=keys)
            # self._logger.info("res.shape: ", res.shape)
            res.index.names = ['epoch', 'iter']
            filter_type = self._model_kwargs.get('filter_type', 'unknown')
            atol = float(self._model_kwargs.get('odeint_atol', 1e-5))
            rtol = float(self._model_kwargs.get('odeint_rtol', 1e-5))
            nfe_file = os.path.join(
                self._data_kwargs.get('dataset_dir', 'data'),
                'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol * 1e5), int(rtol * 1e5)))
            res.to_pickle(nfe_file)
            # res.to_csv(nfe_file)

    def _prepare_data(self, x, y):
        x, y = self._get_x_y(x, y)
        x, y = self._get_x_y_in_correct_dims(x, y)
        return x.to(device), y.to(device)

    def _get_x_y(self, x, y):
        """
        :param x: shape (batch_size, seq_len, num_edges, input_dim)
        :param y: shape (batch_size, horizon, num_edges, input_dim)
        :return: x shape (seq_len, batch_size, num_edges, input_dim)
                 y shape (horizon, batch_size, num_edges, input_dim)
        """
        x = torch.from_numpy(x).float()
        y = torch.from_numpy(y).float()
        self._logger.debug("X: {}".format(x.size()))
        self._logger.debug("y: {}".format(y.size()))
        x = x.permute(1, 0, 2, 3)
        y = y.permute(1, 0, 2, 3)
        return x, y

    def _get_x_y_in_correct_dims(self, x, y):
        """
        :param x: shape (seq_len, batch_size, num_edges, input_dim)
        :param y: shape (horizon, batch_size, num_edges, input_dim)
        :return: x: shape (seq_len, batch_size, num_edges * input_dim)
                 y: shape (horizon, batch_size, num_edges * output_dim)
        """
        batch_size = x.size(1)
        self._logger.debug("size of x {}".format(x.size()))
        x = x.view(self.seq_len, batch_size, self.num_edges * self.input_dim)
        y = y[..., :self.output_dim].view(self.horizon, batch_size,
                                          self.num_edges * self.output_dim)
        return x, y

    def _compute_loss(self, y_true, y_predicted):
        y_true = self.standard_scaler.inverse_transform(y_true)
        y_predicted = self.standard_scaler.inverse_transform(y_predicted)
        return masked_mae_loss(y_predicted, y_true)

    def _compute_loss_eval(self, y_true, y_predicted):
        y_true = self.standard_scaler.inverse_transform(y_true)
        y_predicted = self.standard_scaler.inverse_transform(y_predicted)
        return (masked_mae_loss(y_predicted, y_true).item(),
                masked_mape_loss(y_predicted, y_true).item(),
                masked_rmse_loss(y_predicted, y_true).item())

    def evaluate(self, dataset='val', batches_seen=0, save=False):
        """
        Computes MAE, MAPE and RMSE over the given dataset split, and also
        collects the predictions if `save` is True.
        :return: mean masked MAE, and a dict of predictions/ground truth (or None)
        """
        with torch.no_grad():
            self.stden_model.eval()

            val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
            mae_losses = []
            mape_losses = []
            rmse_losses = []
            y_dict = None

            if save:
                y_truths = []
                y_preds = []

            for _, (x, y) in enumerate(val_iterator):
                x, y = self._prepare_data(x, y)

                output, fe = self.stden_model(x)
                mae, mape, rmse = self._compute_loss_eval(y, output)
                mae_losses.append(mae)
                mape_losses.append(mape)
                rmse_losses.append(rmse)

                if save:
                    y_truths.append(y.cpu())
                    y_preds.append(output.cpu())

            mean_loss = {
                'mae': np.mean(mae_losses),
                'mape': np.mean(mape_losses),
                'rmse': np.mean(rmse_losses)
            }

            self._logger.info('Evaluation: - mae - {:.4f} - mape - {:.4f} - rmse - {:.4f}'.format(mean_loss['mae'], mean_loss['mape'], mean_loss['rmse']))
            self._writer.add_scalar('{} loss'.format(dataset), mean_loss['mae'], batches_seen)

            if save:
                y_preds = np.concatenate(y_preds, axis=1)
                y_truths = np.concatenate(y_truths, axis=1)  # concatenate on batch dimension

                y_truths_scaled = []
                y_preds_scaled = []
                # self._logger.debug("y_preds shape: {}, y_truth shape {}".format(y_preds.shape, y_truths.shape))
                for t in range(y_preds.shape[0]):
                    y_truth = self.standard_scaler.inverse_transform(y_truths[t])
                    y_pred = self.standard_scaler.inverse_transform(y_preds[t])
                    y_truths_scaled.append(y_truth)
                    y_preds_scaled.append(y_pred)

                y_preds_scaled = np.stack(y_preds_scaled)
                y_truths_scaled = np.stack(y_truths_scaled)

                y_dict = {'prediction': y_preds_scaled, 'truth': y_truths_scaled}

                # save_dir = self._data_kwargs.get('dataset_dir', 'data')
                # save_path = os.path.join(save_dir, 'pred.npz')
                # np.savez(save_path, prediction=y_preds_scaled, truth=y_truths_scaled)

            return mean_loss['mae'], y_dict

    def eval_more(self, dataset='val', save=False, seq_len=[3, 6, 9, 12], extract_latent=False):
        """
        Computes MAE, MAPE and RMSE at the given prediction steps, and saves the
        predictions to disk if `save` is True.
        """
        self._logger.info('Model mode: Evaluation')
        with torch.no_grad():
            self.stden_model.eval()

            val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
            mae_losses = []
            mape_losses = []
            rmse_losses = []

            if save:
                y_truths = []
                y_preds = []

            if extract_latent:
                latents = []

            # used to record the ODE solver's number of function evaluations (NFE)
            c = []
            for _, (x, y) in enumerate(val_iterator):
                x, y = self._prepare_data(x, y)

                output, fe = self.stden_model(x)
                mae, mape, rmse = [], [], []
                for seq in seq_len:
                    _mae, _mape, _rmse = self._compute_loss_eval(y[seq - 1], output[seq - 1])
                    mae.append(_mae)
                    mape.append(_mape)
                    rmse.append(_rmse)
                mae_losses.append(mae)
                mape_losses.append(mape)
                rmse_losses.append(rmse)
                c.append([*fe, np.mean(mae)])

                if save:
                    y_truths.append(y.cpu())
                    y_preds.append(output.cpu())

                if extract_latent:
                    latents.append(self.stden_model.latent_feat.cpu())

            mean_loss = {
                'mae': np.mean(mae_losses, axis=0),
                'mape': np.mean(mape_losses, axis=0),
                'rmse': np.mean(rmse_losses, axis=0)
            }

            for i, seq in enumerate(seq_len):
                self._logger.info('Evaluation seq {}: - mae - {:.4f} - mape - {:.4f} - rmse - {:.4f}'.format(
                    seq, mean_loss['mae'][i], mean_loss['mape'][i], mean_loss['rmse'][i]))

            if save:
                # shape (horizon, num_samples, feat_dim)
                y_preds = np.concatenate(y_preds, axis=1)
                y_truths = np.concatenate(y_truths, axis=1)  # concatenate on batch dimension
                y_preds_scaled = self.standard_scaler.inverse_transform(y_preds)
                y_truths_scaled = self.standard_scaler.inverse_transform(y_truths)

                save_dir = self._data_kwargs.get('dataset_dir', 'data')
                save_path = os.path.join(save_dir, 'pred_{}_{}.npz'.format(self.experimentID, self._epoch_num))
                np.savez_compressed(save_path, prediction=y_preds_scaled, truth=y_truths_scaled)

            if extract_latent:
                # concatenate on batch dimension
                latents = np.concatenate(latents, axis=1)
                # shape of latents: (horizon, num_samples, self.num_edges * self.output_dim)

                save_dir = self._data_kwargs.get('dataset_dir', 'data')
                filter_type = self._model_kwargs.get('filter_type', 'unknown')
                save_path = os.path.join(save_dir, '{}_latent_{}_{}.npz'.format(filter_type, self.experimentID, self._epoch_num))
                np.savez_compressed(save_path, latent=latents)

            if bool(self._model_kwargs.get('nfe', False)):
                res = pd.DataFrame(c, columns=['nfe', 'time', 'err'])
                res.index.name = 'iter'
                filter_type = self._model_kwargs.get('filter_type', 'unknown')
                atol = float(self._model_kwargs.get('odeint_atol', 1e-5))
                rtol = float(self._model_kwargs.get('odeint_rtol', 1e-5))
                nfe_file = os.path.join(
                    self._data_kwargs.get('dataset_dir', 'data'),
                    'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol * 1e5), int(rtol * 1e5)))
                res.to_pickle(nfe_file)
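

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream file). It shows
# how a training script could drive STDENSupervisor. The config path, the
# adjacency-matrix file, and the way `adj_mx` is loaded are assumptions; the
# repository's actual entry script may differ.
# ---------------------------------------------------------------------------
# import numpy as np
# import yaml
#
# with open('data/model/stden_config.yaml') as f:   # hypothetical config file
#     config = yaml.safe_load(f)                    # expects 'data', 'model', 'train' sections
# adj_mx = np.load('data/adj_mx.npy')               # hypothetical adjacency matrix
#
# supervisor = STDENSupervisor(adj_mx, **config)
# supervisor.train()                                # train with early stopping
# supervisor.eval_more(dataset='test', save=True)   # per-horizon test metrics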