commit 3a5f5e517050e3235563eba821b5419dd0a92ff4 Author: Echo-Ji Date: Sun Nov 7 19:43:38 2021 +0800 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..81e973d --- /dev/null +++ b/.gitignore @@ -0,0 +1,109 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class +logs/ +runs/ +ckpt/ +data/ +.vscode/ +figures/ + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# pycharm +.idea/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..ff4c0fb --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Echo Ji + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, 
publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..bea7461 --- /dev/null +++ b/README.md @@ -0,0 +1,39 @@ +# STDEN + +This is a PyTorch implementation of the paper *Towards Physics-guided Neural Networks for Traffic Flow Prediction*. + +## Requirement + +* scipy>=1.5.2 +* numpy>=1.19.1 +* pandas>=1.1.5 +* pyyaml>=5.3.1 +* pytorch>=1.7.1 +* future>=0.18.2 +* torchdiffeq>=0.2.0 + +Dependencies can be installed using the following command: + +``` +pip install -r requirements.txt +``` + +## Model Training and Evaluation + +One can run the code by +```bash +# training for dataset GT-221 +python stden_train.py --config_filename=configs/stde_gt.yaml + +# testing for dataset GT-221 +python stden_eval.py --config_filename=configs/stde_gt.yaml +``` +The configuration files for all datasets are as follows: + +|dataset|config file| +|:--|:--| +|GT-221|stde_gt.yaml| +|WRS-393|stde_wrs.yaml| +|ZGC-564|stde_zgc.yaml| + +PS: The data is not public and I am not allowed to distribute it. 
diff --git a/configs/stde_gt.yaml b/configs/stde_gt.yaml new file mode 100644 index 0000000..1b2ca18 --- /dev/null +++ b/configs/stde_gt.yaml @@ -0,0 +1,44 @@ +--- +log_base_dir: logs/BJ_GM +log_level: INFO + +data: + batch_size: 32 + dataset_dir: data/BJ_GM + val_batch_size: 32 + graph_pkl_filename: data/sensor_graph/adj_GM.npy + +model: + l1_decay: 0 + seq_len: 12 + horizon: 12 + input_dim: 1 + output_dim: 1 + latent_dim: 4 + n_traj_samples: 3 + ode_method: dopri5 + odeint_atol: 0.00001 + odeint_rtol: 0.00001 + rnn_units: 64 + num_rnn_layers: 1 + gcn_step: 2 + filter_type: default # unkP IncP default + recg_type: gru + save_latent: false + nfe: false + +train: + base_lr: 0.01 + dropout: 0 + load: 0 + epoch: 0 + epochs: 100 + epsilon: 1.0e-3 + lr_decay_ratio: 0.1 + max_grad_norm: 5 + min_learning_rate: 2.0e-06 + optimizer: adam + patience: 20 + steps: [20, 30, 40, 50] + test_every_n_epochs: 5 + \ No newline at end of file diff --git a/configs/stde_wrs.yaml b/configs/stde_wrs.yaml new file mode 100644 index 0000000..8402736 --- /dev/null +++ b/configs/stde_wrs.yaml @@ -0,0 +1,44 @@ +--- +log_base_dir: logs/BJ_RM +log_level: INFO + +data: + batch_size: 32 + dataset_dir: data/BJ_RM + val_batch_size: 32 + graph_pkl_filename: data/sensor_graph/adj_RM.npy + +model: + l1_decay: 0 + seq_len: 12 + horizon: 12 + input_dim: 1 + output_dim: 1 + latent_dim: 4 + n_traj_samples: 3 + ode_method: dopri5 + odeint_atol: 0.00001 + odeint_rtol: 0.00001 + rnn_units: 64 # for recognition + num_rnn_layers: 1 + gcn_step: 2 + filter_type: default # unkP IncP default + recg_type: gru + save_latent: false + nfe: false + +train: + base_lr: 0.01 + dropout: 0 + load: 0 # 0 for not load + epoch: 0 + epochs: 100 + epsilon: 1.0e-3 + lr_decay_ratio: 0.1 + max_grad_norm: 5 + min_learning_rate: 2.0e-06 + optimizer: adam + patience: 20 + steps: [20, 30, 40, 50] + test_every_n_epochs: 5 + \ No newline at end of file diff --git a/configs/stde_zgc.yaml b/configs/stde_zgc.yaml new file mode 100644 index 
import torch

def masked_mae_loss(y_pred, y_true):
    """Masked mean absolute error.

    Ground-truth entries below 1e-4 are treated as missing: they receive zero
    weight, and the weight of the observed entries is scaled up so the result
    stays a mean over observed entries only.

    :param y_pred: predicted tensor, same shape as y_true.
    :param y_true: ground-truth tensor; values below 1e-4 count as missing.
    :return: scalar loss tensor.
    """
    # Fix: do NOT zero y_true in place -- the original mutated the caller's
    # tensor, corrupting ground truth reused by subsequent metric calls.
    y_true = torch.where(y_true < 1e-4, torch.zeros_like(y_true), y_true)
    mask = (y_true != 0).float()
    mask /= mask.mean()  # re-distribute the weight of masked entries to observed ones
    loss = torch.abs(y_pred - y_true) * mask
    # An all-zero y_true makes mask.mean() == 0 and produces NaNs; drop them
    # (same trick as https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/3).
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return loss.mean()

def masked_mape_loss(y_pred, y_true):
    """Masked mean absolute percentage error.

    Same masking scheme as :func:`masked_mae_loss`; division by the masked
    zeros yields inf, inf * 0 (mask) yields NaN, and those NaNs are zeroed
    out so masked entries contribute nothing.

    :param y_pred: predicted tensor, same shape as y_true.
    :param y_true: ground-truth tensor; values below 1e-4 count as missing.
    :return: scalar loss tensor.
    """
    # Fix: no in-place mutation of the caller's y_true (see masked_mae_loss).
    y_true = torch.where(y_true < 1e-4, torch.zeros_like(y_true), y_true)
    mask = (y_true != 0).float()
    mask /= mask.mean()
    loss = torch.abs((y_pred - y_true) / y_true) * mask
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return loss.mean()
import logging
import os
import pickle
import sys
import time

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
from scipy.sparse import linalg


class DataLoader(object):
    """Minibatch iterator over a pair of aligned sample arrays.

    Optionally pads the data with copies of the last sample so the total
    size is divisible by the batch size, and optionally shuffles once at
    construction time.
    """

    def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
        """
        :param xs: input samples; axis 0 is the sample axis.
        :param ys: target samples, aligned with xs.
        :param batch_size: number of samples per batch.
        :param pad_with_last_sample: pad with the last sample to make the number
            of samples divisible by batch_size.
        :param shuffle: shuffle the (padded) samples once.
        """
        self.batch_size = batch_size
        self.current_ind = 0
        if pad_with_last_sample:
            shortfall = (batch_size - (len(xs) % batch_size)) % batch_size
            xs = np.concatenate([xs, np.repeat(xs[-1:], shortfall, axis=0)], axis=0)
            ys = np.concatenate([ys, np.repeat(ys[-1:], shortfall, axis=0)], axis=0)
        self.size = len(xs)
        self.num_batch = int(self.size // self.batch_size)
        if shuffle:
            order = np.random.permutation(self.size)
            xs, ys = xs[order], ys[order]
        self.xs = xs
        self.ys = ys

    def get_iterator(self):
        """Return a generator over (x, y) batches; resets the batch cursor."""
        self.current_ind = 0

        def _batches():
            while self.current_ind < self.num_batch:
                lo = self.batch_size * self.current_ind
                hi = min(self.size, lo + self.batch_size)
                yield (self.xs[lo:hi, ...], self.ys[lo:hi, ...])
                self.current_ind += 1

        return _batches()


class StandardScaler:
    """Z-score normalization with a fixed mean and standard deviation."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        """Map raw data into standardized space."""
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        """Map standardized data back into raw space."""
        return (data * self.std) + self.mean


def calculate_random_walk_matrix(adj_mx):
    """Return the random-walk transition matrix D^-1 * A as a sparse COO matrix.

    Zero-degree rows get an all-zero row: their inverse degree is clamped to 0
    instead of inf.
    """
    adj = sp.coo_matrix(adj_mx)
    deg_inv = np.power(np.array(adj.sum(1)), -1).flatten()
    deg_inv[np.isinf(deg_inv)] = 0.
    return sp.diags(deg_inv).dot(adj).tocoo()


def config_logging(log_dir, log_filename='info.log', level=logging.INFO):
    """Configure root logging with a file handler and a stdout handler."""
    try:
        # Create the log directory if necessary; best-effort, as before.
        os.makedirs(log_dir)
    except OSError:
        pass
    file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    file_handler.setLevel(level=level)
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    console_handler.setLevel(level=level)
    logging.basicConfig(handlers=[file_handler, console_handler], level=level)
def get_log_dir(kwargs):
    """Resolve the logging directory from the config dict.

    If ``train.log_dir`` is given it is used as-is; otherwise a run id is
    assembled from the model/data hyper-parameters plus a timestamp and
    appended to ``log_base_dir``. The directory is created if missing.

    :param kwargs: full config dict with 'data', 'model', 'train' sections.
    :return: path of the log directory.
    """
    log_dir = kwargs['train'].get('log_dir')
    if log_dir is None:
        data_cfg = kwargs['data']
        model_cfg = kwargs['model']
        # Short tag for the physical prior used by the ODE function.
        filter_type_abbr = {'unkP': 'UP', 'IncP': 'NV'}.get(
            model_cfg.get('filter_type'), 'DF')
        run_id = 'STDEN_%s-%d_%s-%d_L-%d_N-%d_M-%s_bs-%d_%d-%d_%s/' % (
            model_cfg.get('recg_type'),
            model_cfg.get('rnn_units'),
            filter_type_abbr,
            model_cfg.get('gcn_step'),
            model_cfg.get('latent_dim'),
            model_cfg.get('n_traj_samples'),
            model_cfg.get('ode_method'),
            data_cfg.get('batch_size'),
            model_cfg.get('seq_len'),
            model_cfg.get('horizon'),
            time.strftime('%m%d%H%M%S'))
        log_dir = os.path.join(kwargs.get('log_base_dir'), run_id)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir
def load_graph_data(pkl_filename):
    """Load the adjacency matrix stored as a numpy ``.npy`` file."""
    adj_mx = np.load(pkl_filename)
    return adj_mx

def graph_grad(adj_mx):
    """Fetch the graph gradient (incidence-like) operator.

    Returns a dense ``(num_nodes, num_edges)`` tensor with +1 at the source
    node and -1 at the target node of every positive-weight edge. Edges are
    enumerated in row-major order of the adjacency matrix, matching the
    original nested-loop scan.
    """
    num_nodes = adj_mx.shape[0]
    # Fix: enumerate positive entries only, consistently with the edge count
    # (the original loop also walked negative entries, which would have
    # overflowed the edge index if the matrix ever held negative weights).
    edges = np.argwhere(adj_mx > 0.)
    grad = torch.zeros(num_nodes, len(edges))
    for e, (i, j) in enumerate(edges):
        grad[i, e] = 1.
        grad[j, e] = -1.
    return grad

def init_network_weights(net, std=0.1):
    """Initialize every ``nn.Linear`` in ``net``: N(0, std) weights, zero bias.

    :param net: module whose Linear sub-modules should be initialized.
    :param std: standard deviation of the normal weight initialization.
    """
    for m in net.modules():
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0, std=std)
            # Fix: Linear(bias=False) has bias None; the original crashed on it.
            if m.bias is not None:
                nn.init.constant_(m.bias, val=0)

def split_last_dim(data):
    """Split the last dimension into two equal halves; return (first, second)."""
    half = data.size()[-1] // 2
    return data[..., :half], data[..., half:]

def get_device(tensor):
    """Return the device the tensor lives on (CPU unless it is on CUDA)."""
    if tensor.is_cuda:
        return tensor.get_device()
    return torch.device("cpu")

def sample_standard_gaussian(mu, sigma):
    """Reparameterized sample from N(mu, sigma) drawn on mu's device."""
    device = get_device(mu)
    d = torch.distributions.normal.Normal(torch.Tensor([0.]).to(device),
                                          torch.Tensor([1.]).to(device))
    r = d.sample(mu.size()).squeeze(-1)
    return r * sigma.float() + mu.float()

def create_net(n_inputs, n_outputs, n_layers=0,
               n_units=100, nonlinear=nn.Tanh):
    """Build an MLP: input layer, ``n_layers`` hidden blocks, output layer.

    :param n_inputs: input feature size.
    :param n_outputs: output feature size.
    :param n_layers: number of extra (nonlinearity + Linear) hidden blocks.
    :param n_units: hidden width.
    :param nonlinear: activation module class placed between Linear layers.
    """
    layers = [nn.Linear(n_inputs, n_units)]
    for _ in range(n_layers):
        layers.append(nonlinear())
        layers.append(nn.Linear(n_units, n_units))
    layers.append(nonlinear())
    layers.append(nn.Linear(n_units, n_outputs))
    return nn.Sequential(*layers)
import numpy as np
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LayerParams:
    """Lazily-created, shape-keyed parameter store for a host nn.Module.

    Parameters are created on first request, registered on the host module
    (so the optimizer sees them) and cached for reuse on later calls.
    """

    def __init__(self, rnn_network: nn.Module, layer_type: str):
        self._rnn_network = rnn_network
        self._params_dict = {}
        self._biases_dict = {}
        self._type = layer_type

    def get_weights(self, shape):
        """Return (creating on first use) a Xavier-initialized weight of ``shape``."""
        if shape not in self._params_dict:
            weight = nn.Parameter(torch.empty(*shape, device=device))
            nn.init.xavier_normal_(weight)
            self._params_dict[shape] = weight
            self._rnn_network.register_parameter(
                '{}_weight_{}'.format(self._type, str(shape)), weight)
        return self._params_dict[shape]

    def get_biases(self, length, bias_start=0.0):
        """Return (creating on first use) a bias vector filled with ``bias_start``."""
        if length not in self._biases_dict:
            bias = nn.Parameter(torch.empty(length, device=device))
            nn.init.constant_(bias, bias_start)
            self._biases_dict[length] = bias
            self._rnn_network.register_parameter(
                '{}_biases_{}'.format(self._type, str(length)), bias)
        return self._biases_dict[length]

class ODEFunc(nn.Module):
    """Latent ODE dynamics dz/dt over the road graph.

    Depending on ``filter_type``, the gradient field is a free-form per-node
    MLP ('unkP'), a negated graph-conv net ('IncP'), or the default diffusion
    process with a learned, state-dependent rate.
    """

    def __init__(self, num_units, latent_dim, adj_mx, gcn_step, num_nodes,
                 gen_layers=1, nonlinearity='tanh', filter_type="default"):
        """
        :param num_units: dimensionality of the hidden layers.
        :param latent_dim: dimensionality used for the ODE state (input and output).
        :param adj_mx: adjacency matrix (numpy array).
        :param gcn_step: number of graph-convolution (Chebyshev) steps.
        :param num_nodes: number of graph nodes.
        :param gen_layers: hidden layers inside the ODE function net.
        :param nonlinearity: 'tanh' or 'relu'.
        :param filter_type: 'unkP', 'IncP', or 'default'.
        """
        super(ODEFunc, self).__init__()
        from lib import utils  # local import: only the constructor needs lib.utils

        self._activation = torch.tanh if nonlinearity == 'tanh' else torch.relu
        self._num_nodes = num_nodes
        self._num_units = num_units  # hidden dimension
        self._latent_dim = latent_dim
        self._gen_layers = gen_layers
        self.nfe = 0  # number of function evaluations, reset by the solver

        self._filter_type = filter_type
        if self._filter_type == "unkP":
            # Unknown physics: a plain MLP models the gradient field.
            net = utils.create_net(latent_dim, latent_dim, n_units=num_units)
            utils.init_network_weights(net)
            self.gradient_net = net
        else:
            self._gcn_step = gcn_step
            self._gconv_params = LayerParams(self, 'gconv')
            # Bidirectional random-walk supports, stored as torch sparse tensors.
            self._supports = [
                self._build_sparse_matrix(s) for s in (
                    utils.calculate_random_walk_matrix(adj_mx).T,
                    utils.calculate_random_walk_matrix(adj_mx.T).T,
                )
            ]

    @staticmethod
    def _build_sparse_matrix(L):
        """Convert a scipy matrix to a torch sparse COO tensor on ``device``."""
        L = L.tocoo()
        indices = np.column_stack((L.row, L.col))
        # Reorder indices to match the old tf sparse_reorder behavior.
        indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
        return torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device)

    def forward(self, t_local, y, backwards=False):
        """Perform one ODE step: return the gradient dy/dt at ``t_local``.

        :param t_local: current time point.
        :param y: state at the current time point, shape (B, num_nodes * latent_dim).
        :param backwards: negate the gradient for backward integration.
        :return: a 2-D tensor with shape (B, num_nodes * latent_dim).
        """
        self.nfe += 1
        grad = self.get_ode_gradient_nn(t_local, y)
        return -grad if backwards else grad

    def get_ode_gradient_nn(self, t_local, inputs):
        """Dispatch the gradient computation according to the filter type."""
        if self._filter_type == "unkP":
            return self._fc(inputs)
        if self._filter_type == "IncP":
            return - self.ode_func_net(inputs)
        # default: diffusion process with learned rate theta in (0, 1)
        # theta shape: (B, num_nodes * latent_dim)
        theta = torch.sigmoid(self._gconv(inputs, self._latent_dim, bias_start=1.0))
        return - theta * self.ode_func_net(inputs)

    def ode_func_net(self, inputs):
        """Stacked graph convolutions: hidden blocks, then projection to latent_dim."""
        h = inputs
        for _ in range(self._gen_layers):
            h = self._activation(self._gconv(h, self._num_units))
        return self._activation(self._gconv(h, self._latent_dim))

    def _fc(self, inputs):
        """Apply the per-node MLP (unkP mode) and restore the flat layout."""
        batch_size = inputs.size()[0]
        out = self.gradient_net(
            inputs.view(batch_size * self._num_nodes, self._latent_dim))
        return out.reshape(batch_size, self._num_nodes * self._latent_dim)

    @staticmethod
    def _concat(x, x_):
        return torch.cat([x, x_.unsqueeze(0)], dim=0)

    def _gconv(self, inputs, output_size, bias_start=0.0):
        """Chebyshev-style graph convolution over all supports.

        :param inputs: (batch_size, num_nodes * input_dim)
        :param output_size: per-node output dimensionality.
        :return: (batch_size, num_nodes * output_size)
        """
        batch_size = inputs.shape[0]
        inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))
        input_size = inputs.size(2)

        x0 = inputs.permute(1, 2, 0)  # (num_nodes, input_size, batch_size)
        x0 = torch.reshape(x0, shape=[self._num_nodes, input_size * batch_size])
        x = torch.unsqueeze(x0, 0)

        if self._gcn_step > 0:
            for support in self._supports:
                x1 = torch.sparse.mm(support, x0)
                x = self._concat(x, x1)
                for _ in range(2, self._gcn_step + 1):
                    x2 = 2 * torch.sparse.mm(support, x1) - x0
                    x = self._concat(x, x2)
                    x1, x0 = x2, x1  # NOTE: preserves the original recursion order

        num_matrices = len(self._supports) * self._gcn_step + 1  # +1 for x itself
        x = torch.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size])
        x = x.permute(3, 1, 2, 0)  # (batch_size, num_nodes, input_size, order)
        x = torch.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices])

        weights = self._gconv_params.get_weights((input_size * num_matrices, output_size))
        x = torch.matmul(x, weights)  # (batch_size * num_nodes, output_size)
        x += self._gconv_params.get_biases(output_size, bias_start)
        # Back to flat 2-D layout: (batch_size, num_nodes * output_size)
        return torch.reshape(x, [batch_size, self._num_nodes * output_size])
+ self.rnn_units = int(model_kwargs.get('rnn_units')) + self.latent_dim = int(model_kwargs.get('latent_dim', 4)) + +class STDENModel(nn.Module, EncoderAttrs): + def __init__(self, adj_mx, logger, **model_kwargs): + nn.Module.__init__(self) + EncoderAttrs.__init__(self, adj_mx, **model_kwargs) + self._logger = logger + #################################################### + # recognition net + #################################################### + self.encoder_z0 = Encoder_z0_RNN(adj_mx, **model_kwargs) + + #################################################### + # ode solver + #################################################### + self.n_traj_samples = int(model_kwargs.get('n_traj_samples', 1)) + self.ode_method = model_kwargs.get('ode_method', 'dopri5') + self.atol = float(model_kwargs.get('odeint_atol', 1e-4)) + self.rtol = float(model_kwargs.get('odeint_rtol', 1e-3)) + self.num_gen_layer = int(model_kwargs.get('gen_layers', 1)) + self.ode_gen_dim = int(model_kwargs.get('gen_dim', 64)) + ode_set_str = "ODE setting --latent {} --samples {} --method {} \ + --atol {:6f} --rtol {:6f} --gen_layer {} --gen_dim {}".format(\ + self.latent_dim, self.n_traj_samples, self.ode_method, \ + self.atol, self.rtol, self.num_gen_layer, self.ode_gen_dim) + odefunc = ODEFunc(self.ode_gen_dim, # hidden dimension + self.latent_dim, + adj_mx, + self.gcn_step, + self.num_nodes, + filter_type=self.filter_type + ).to(device) + self.diffeq_solver = DiffeqSolver(odefunc, + self.ode_method, + self.latent_dim, + odeint_rtol=self.rtol, + odeint_atol=self.atol + ) + self._logger.info(ode_set_str) + + self.save_latent = bool(model_kwargs.get('save_latent', False)) + self.latent_feat = None # used to extract the latent feature + + #################################################### + # decoder + #################################################### + self.horizon = int(model_kwargs.get('horizon', 1)) + self.out_feat = int(model_kwargs.get('output_dim', 1)) + self.decoder = Decoder( + self.out_feat, 
class Encoder_z0_RNN(nn.Module, EncoderAttrs):
    """Recognition network.

    Runs a GRU over the edge-level input sequence, maps the final hidden
    state from edge space to node space through a pseudo-inverse-like
    transpose of the graph gradient, and emits the mean/std of the initial
    latent state z0.
    """

    def __init__(self, adj_mx, **model_kwargs):
        nn.Module.__init__(self)
        EncoderAttrs.__init__(self, adj_mx, **model_kwargs)
        self.recg_type = model_kwargs.get('recg_type', 'gru')  # only 'gru' supported

        if self.recg_type == 'gru':
            # gru settings
            self.input_dim = int(model_kwargs.get('input_dim', 1))
            self.gru_rnn = GRU(self.input_dim, self.rnn_units).to(device)
        else:
            raise NotImplementedError("The recognition net only support 'gru'.")

        # hidden-to-z0 settings: edge space -> node space, every incident edge
        # contributing with weight 0.5
        self.inv_grad = utils.graph_grad(adj_mx).transpose(-2, -1)
        self.inv_grad[self.inv_grad != 0.] = 0.5
        self.hiddens_to_z0 = nn.Sequential(
            nn.Linear(self.rnn_units, 50),
            nn.Tanh(),
            nn.Linear(50, self.latent_dim * 2),)

        utils.init_network_weights(self.hiddens_to_z0)

    def forward(self, inputs):
        """
        Encoder forward pass over the whole input sequence.

        :param inputs: shape (seq_len, batch_size, num_edges * input_dim)
        :return: (mean, std), each of shape (n_samples=1, batch_size, num_nodes * latent_dim)
        """
        if self.recg_type == 'gru':
            seq_len, batch_size = inputs.size(0), inputs.size(1)
            # fold edges into the batch axis so the GRU runs per edge
            inputs = inputs.reshape(seq_len, batch_size, self.num_edges, self.input_dim)
            inputs = inputs.reshape(seq_len, batch_size * self.num_edges, self.input_dim)

            outputs, _ = self.gru_rnn(inputs)
            last_output = outputs[-1]
            # (batch_size, num_edges, rnn_units)
            last_output = torch.reshape(last_output, (batch_size, self.num_edges, -1))
            # Fix: torch.transpose takes two int dims (dim0, dim1), not a tuple;
            # the original torch.transpose(last_output, (-2, -1)) raised TypeError.
            last_output = last_output.transpose(-2, -1)
            # (batch_size, num_nodes, rnn_units)
            last_output = torch.matmul(last_output, self.inv_grad).transpose(-2, -1)
        else:
            raise NotImplementedError("The recognition net only support 'gru'.")

        mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
        mean = mean.reshape(batch_size, -1)  # (batch_size, num_nodes * latent_dim)
        std = std.reshape(batch_size, -1)    # (batch_size, num_nodes * latent_dim)
        std = std.abs()  # std must be non-negative

        assert(not torch.isnan(mean).any())
        assert(not torch.isnan(std).any())

        # leading n_sample axis for trajectory sampling downstream
        return mean.unsqueeze(0), std.unsqueeze(0)
class Decoder(nn.Module):
    """Map node-space latent trajectories back to edge-space flows.

    Uses the graph gradient operator to transform node values into edge
    values, then averages over latent dimensions and trajectory samples.
    """

    def __init__(self, output_dim, adj_mx, num_nodes, num_edges):
        super(Decoder, self).__init__()
        self.output_dim = output_dim
        self.num_nodes = num_nodes
        self.num_edges = num_edges
        self.grap_grad = utils.graph_grad(adj_mx)

    def forward(self, inputs):
        """
        :param inputs: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
        :return: (horizon, batch_size, num_edges * output_dim), averaged over
            latent dimensions and trajectory samples.
        """
        assert inputs.dim() == 4
        horizon, n_traj_samples, batch_size = inputs.size()[:3]

        # (horizon, n_traj, batch, latent_dim, num_nodes)
        z = inputs.reshape(horizon, n_traj_samples, batch_size, self.num_nodes, -1)
        z = z.transpose(-2, -1)
        latent_dim = z.size(-2)

        # node space -> edge space via the graph gradient
        flows = torch.matmul(z, self.grap_grad)
        flows = flows.reshape(horizon, n_traj_samples, batch_size,
                              latent_dim, self.num_edges, self.output_dim)
        # average over latent dims (axis 3), then over trajectory samples (axis 1)
        flows = flows.mean(dim=3).mean(dim=1)
        return flows.reshape(horizon, batch_size, -1)
        self._log_dir = utils.get_log_dir(kwargs)
        self._writer = SummaryWriter('runs/' + self._log_dir)  # tensorboard writer

        log_level = self._kwargs.get('log_level', 'INFO')
        self._logger = utils.get_logger(self._log_dir, __name__, 'info.log', level=log_level)

        # data set
        self._data = utils.load_dataset(**self._data_kwargs)
        self.standard_scaler = self._data['scaler']
        self._logger.info('Scaler mean: {:.6f}, std {:.6f}.'.format(self.standard_scaler.mean, self.standard_scaler.std))

        # number of edges = count of positive entries in the adjacency matrix
        self.num_edges = (adj_mx > 0.).sum()
        self.input_dim = int(self._model_kwargs.get('input_dim', 1))
        self.seq_len = int(self._model_kwargs.get('seq_len'))  # for the encoder
        self.output_dim = int(self._model_kwargs.get('output_dim', 1))
        self.use_curriculum_learning = bool(
            self._model_kwargs.get('use_curriculum_learning', False))
        self.horizon = int(self._model_kwargs.get('horizon', 1))  # for the decoder

        # setup model
        stden_model = STDENModel(adj_mx, self._logger, **self._model_kwargs)
        self.stden_model = stden_model.cuda() if torch.cuda.is_available() else stden_model
        self._logger.info("Model created")

        # 'load' selects an existing experiment ID to resume; 0 means start fresh
        self.experimentID = self._train_kwargs.get('load', 0)
        if self.experimentID == 0:
            # Make a new experiment ID
            self.experimentID = int(SystemRandom().random()*100000)
        self.ckpt_path = os.path.join("ckpt/", "experiment_" + str(self.experimentID))

        self._epoch_num = self._train_kwargs.get('epoch', 0)
        if self._epoch_num > 0:
            self._logger.info('Loading model...')
            self.load_model()

    def save_model(self, epoch):
        """Save model weights plus the full run config; return the checkpoint path."""
        model_dir = self.ckpt_path
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        config = dict(self._kwargs)
        config['model_state_dict'] = self.stden_model.state_dict()
        config['epoch'] = epoch
        model_path = os.path.join(model_dir, 'epo{}.tar'.format(epoch))
        torch.save(config, model_path)
        self._logger.info("Saved model at {}".format(epoch))
        return model_path

    def load_model(self):
        """Restore the weights saved at self._epoch_num from the experiment dir."""
        self._setup_graph()
        model_path =
os.path.join(self.ckpt_path, 'epo{}.tar'.format(self._epoch_num))
        assert os.path.exists(model_path), 'Weights at epoch %d not found' % self._epoch_num

        checkpoint = torch.load(model_path, map_location='cpu')
        self.stden_model.load_state_dict(checkpoint['model_state_dict'])
        self._logger.info("Loaded model at {}".format(self._epoch_num))

    def _setup_graph(self):
        """Run one forward pass so lazily-created parameters exist before load_state_dict."""
        with torch.no_grad():
            self.stden_model.eval()

            val_iterator = self._data['val_loader'].get_iterator()

            for _, (x, y) in enumerate(val_iterator):
                x, y = self._prepare_data(x, y)
                output = self.stden_model(x)
                break  # a single batch is enough to materialize the graph

    def train(self, **kwargs):
        """Public entry point; hyper-parameters come from the 'train' config section."""
        self._logger.info('Model mode: train')
        kwargs.update(self._train_kwargs)
        return self._train(**kwargs)

    def _train(self, base_lr,
               steps, patience=50, epochs=100, lr_decay_ratio=0.1, log_every=1, save_model=1,
               test_every_n_epochs=10, epsilon=1e-8, **kwargs):
        # steps is used in learning rate - will see if need to use it?
        min_val_loss = float('inf')
        wait = 0  # epochs since the last val improvement (early-stopping counter)
        optimizer = torch.optim.Adam(self.stden_model.parameters(), lr=base_lr, eps=epsilon)

        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
                                                            gamma=lr_decay_ratio)

        self._logger.info('Start training ...')

        # this will fail if model is loaded with a changed batch_size
        num_batches = self._data['train_loader'].num_batch
        self._logger.info("num_batches: {}".format(num_batches))

        batches_seen = num_batches * self._epoch_num

        # used for nfe (presumably ODE-solver function evaluations; cf. the
        # odeint_atol / odeint_rtol config keys — TODO confirm)
        c = []
        res, keys = [], []

        for epoch_num in range(self._epoch_num, epochs):

            self.stden_model.train()

            train_iterator = self._data['train_loader'].get_iterator()
            losses = []

            start_time = time.time()

            c.clear() #nfe
            for i, (x, y) in enumerate(train_iterator):
                if(i >= num_batches):
                    break
                optimizer.zero_grad()

                x, y = self._prepare_data(x, y)

                output, fe = self.stden_model(x, y, batches_seen)

                if batches_seen == 0:
                    # this is a workaround to accommodate
dynamically registered parameters + optimizer = torch.optim.Adam(self.stden_model.parameters(), lr=base_lr, eps=epsilon) + + loss = self._compute_loss(y, output) + self._logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item())) + c.append([*fe, loss.item()]) + + self._logger.debug(loss.item()) + losses.append(loss.item()) + + batches_seen += 1 # global step in tensorboard + loss.backward() + + # gradient clipping + torch.nn.utils.clip_grad_norm_(self.stden_model.parameters(), self.max_grad_norm) + + optimizer.step() + + del x, y, output, loss # del make these memory no-labeled trash + torch.cuda.empty_cache() # empty_cache() recycle no-labeled trash + + # used for nfe + res.append(pd.DataFrame(c, columns=['nfe', 'time', 'err'])) + keys.append(epoch_num) + + self._logger.info("epoch complete") + lr_scheduler.step() + self._logger.info("evaluating now!") + + val_loss, _ = self.evaluate(dataset='val', batches_seen=batches_seen) + + end_time = time.time() + + self._writer.add_scalar('training loss', + np.mean(losses), + batches_seen) + + if (epoch_num % log_every) == log_every - 1: + message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}, ' \ + '{:.1f}s'.format(epoch_num, epochs, batches_seen, + np.mean(losses), val_loss, lr_scheduler.get_lr()[0], + (end_time - start_time)) + self._logger.info(message) + + if (epoch_num % test_every_n_epochs) == test_every_n_epochs - 1: + test_loss, _ = self.evaluate(dataset='test', batches_seen=batches_seen) + message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \ + '{:.1f}s'.format(epoch_num, epochs, batches_seen, + np.mean(losses), test_loss, lr_scheduler.get_lr()[0], + (end_time - start_time)) + self._logger.info(message) + + if val_loss < min_val_loss: + wait = 0 + if save_model: + model_file_name = self.save_model(epoch_num) + self._logger.info( + 'Val loss decrease from {:.4f} to {:.4f}, ' + 'saving to {}'.format(min_val_loss, val_loss, 
model_file_name))
                    min_val_loss = val_loss

            elif val_loss >= min_val_loss:
                wait += 1
                if wait == patience:
                    self._logger.warning('Early stopping at epoch: %d' % epoch_num)
                    break

        if bool(self._model_kwargs.get('nfe', False)):
            # persist per-iteration nfe/time/err statistics collected during training
            res = pd.concat(res, keys=keys)
            # self._logger.info("res.shape: ", res.shape)
            res.index.names = ['epoch', 'iter']
            filter_type = self._model_kwargs.get('filter_type', 'unknown')
            atol = float(self._model_kwargs.get('odeint_atol', 1e-5))
            rtol = float(self._model_kwargs.get('odeint_rtol', 1e-5))
            nfe_file = os.path.join(
                self._data_kwargs.get('dataset_dir', 'data'),
                'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
            res.to_pickle(nfe_file)
            # res.to_csv(nfe_file)

    def _prepare_data(self, x, y):
        """Convert numpy batches to time-major float tensors and move them to device."""
        x, y = self._get_x_y(x, y)
        x, y = self._get_x_y_in_correct_dims(x, y)
        return x.to(device), y.to(device)

    def _get_x_y(self, x, y):
        """
        :param x: shape (batch_size, seq_len, num_edges, input_dim)
        :param y: shape (batch_size, horizon, num_edges, input_dim)
        :returns x shape (seq_len, batch_size, num_edges, input_dim)
                 y shape (horizon, batch_size, num_edges, input_dim)
        """
        x = torch.from_numpy(x).float()
        y = torch.from_numpy(y).float()
        self._logger.debug("X: {}".format(x.size()))
        self._logger.debug("y: {}".format(y.size()))
        # swap batch and time axes -> time-major layout
        x = x.permute(1, 0, 2, 3)
        y = y.permute(1, 0, 2, 3)
        return x, y

    def _get_x_y_in_correct_dims(self, x, y):
        """
        :param x: shape (seq_len, batch_size, num_edges, input_dim)
        :param y: shape (horizon, batch_size, num_edges, input_dim)
        :return: x: shape (seq_len, batch_size, num_edges * input_dim)
                 y: shape (horizon, batch_size, num_edges * output_dim)
        """
        batch_size = x.size(1)
        self._logger.debug("size of x {}".format(x.size()))
        x = x.view(self.seq_len, batch_size, self.num_edges * self.input_dim)
        # keep only the first output_dim features of the targets
        y = y[..., :self.output_dim].view(self.horizon, batch_size,
                                          self.num_edges * self.output_dim)
        return x, y

    def _compute_loss(self,
y_true, y_predicted):
        # masked MAE on de-normalized values (training objective)
        y_true = self.standard_scaler.inverse_transform(y_true)
        y_predicted = self.standard_scaler.inverse_transform(y_predicted)
        return masked_mae_loss(y_predicted, y_true)

    def _compute_loss_eval(self, y_true, y_predicted):
        """Return (mae, mape, rmse) as python floats on de-normalized values."""
        y_true = self.standard_scaler.inverse_transform(y_true)
        y_predicted = self.standard_scaler.inverse_transform(y_predicted)
        return masked_mae_loss(y_predicted, y_true).item(), masked_mape_loss(y_predicted, y_true).item(), masked_rmse_loss(y_predicted, y_true).item()

    def evaluate(self, dataset='val', batches_seen=0, save=False):
        """
        Computes mae rmse mape loss and the predict if save
        :return: mean L1Loss
        """
        with torch.no_grad():
            self.stden_model.eval()

            val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
            mae_losses = []
            mape_losses = []
            rmse_losses = []
            y_dict = None  # stays None unless save=True

            if(save):
                y_truths = []
                y_preds = []

            for _, (x, y) in enumerate(val_iterator):
                x, y = self._prepare_data(x, y)

                output, fe = self.stden_model(x)
                mae, mape, rmse = self._compute_loss_eval(y, output)
                mae_losses.append(mae)
                mape_losses.append(mape)
                rmse_losses.append(rmse)

                if(save):
                    y_truths.append(y.cpu())
                    y_preds.append(output.cpu())

            # metrics averaged over batches
            mean_loss = {
                'mae': np.mean(mae_losses),
                'mape': np.mean(mape_losses),
                'rmse': np.mean(rmse_losses)
            }

            self._logger.info('Evaluation: - mae - {:.4f} - mape - {:.4f} - rmse - {:.4f}'.format(mean_loss['mae'], mean_loss['mape'], mean_loss['rmse']))
            self._writer.add_scalar('{} loss'.format(dataset), mean_loss['mae'], batches_seen)

            if(save):
                y_preds = np.concatenate(y_preds, axis=1)
                y_truths = np.concatenate(y_truths, axis=1)  # concatenate on batch dimension

                y_truths_scaled = []
                y_preds_scaled = []
                # self._logger.debug("y_preds shape: {}, y_truth shape {}".format(y_preds.shape, y_truths.shape))
                for t in range(y_preds.shape[0]):
                    y_truth = self.standard_scaler.inverse_transform(y_truths[t])
                    y_pred =
self.standard_scaler.inverse_transform(y_preds[t]) + y_truths_scaled.append(y_truth) + y_preds_scaled.append(y_pred) + + y_preds_scaled = np.stack(y_preds_scaled) + y_truths_scaled = np.stack(y_truths_scaled) + + y_dict = {'prediction': y_preds_scaled, 'truth': y_truths_scaled} + + # save_dir = self._data_kwargs.get('dataset_dir', 'data') + # save_path = os.path.join(save_dir, 'pred.npz') + # np.savez(save_path, prediction=y_preds_scaled, turth=y_truths_scaled) + + return mean_loss['mae'], y_dict + + def eval_more(self, dataset='val', save=False, seq_len=[3, 6, 9, 12], extract_latent=False): + """ + Computes mae rmse mape loss and the prediction if `save` is set True. + """ + self._logger.info('Model mode: Evaluation') + with torch.no_grad(): + self.stden_model.eval() + + val_iterator = self._data['{}_loader'.format(dataset)].get_iterator() + mae_losses = [] + mape_losses = [] + rmse_losses = [] + + if(save): + y_truths = [] + y_preds = [] + + if(extract_latent): + latents = [] + + # used for nfe + c = [] + for _, (x, y) in enumerate(val_iterator): + x, y = self._prepare_data(x, y) + + output, fe = self.stden_model(x) + mae, mape, rmse = [], [], [] + for seq in seq_len: + _mae, _mape, _rmse = self._compute_loss_eval(y[seq-1], output[seq-1]) + mae.append(_mae) + mape.append(_mape) + rmse.append(_rmse) + mae_losses.append(mae) + mape_losses.append(mape) + rmse_losses.append(rmse) + c.append([*fe, np.mean(mae)]) + + if(save): + y_truths.append(y.cpu()) + y_preds.append(output.cpu()) + + if(extract_latent): + latents.append(self.stden_model.latent_feat.cpu()) + + mean_loss = { + 'mae': np.mean(mae_losses, axis=0), + 'mape': np.mean(mape_losses, axis=0), + 'rmse': np.mean(rmse_losses, axis=0) + } + + for i, seq in enumerate(seq_len): + self._logger.info('Evaluation seq {}: - mae - {:.4f} - mape - {:.4f} - rmse - {:.4f}'.format( + seq, mean_loss['mae'][i], mean_loss['mape'][i], mean_loss['rmse'][i])) + + if(save): + # shape (horizon, num_sapmles, feat_dim) + y_preds = 
np.concatenate(y_preds, axis=1)
                y_truths = np.concatenate(y_truths, axis=1)  # concatenate on batch dimension
                y_preds_scaled = self.standard_scaler.inverse_transform(y_preds)
                y_truths_scaled = self.standard_scaler.inverse_transform(y_truths)

                save_dir = self._data_kwargs.get('dataset_dir', 'data')
                save_path = os.path.join(save_dir, 'pred_{}_{}.npz'.format(self.experimentID, self._epoch_num))
                # NOTE(review): key 'turth' is a typo for 'truth'; kept as-is so
                # existing readers of these .npz files keep working — rename the
                # key and its consumers together.
                np.savez_compressed(save_path, prediction=y_preds_scaled, turth=y_truths_scaled)

            if(extract_latent):
                # concatenate on batch dimension
                latents = np.concatenate(latents, axis=1)
                # Shape of latents (horizon, num_samples, self.num_edges * self.output_dim)

                save_dir = self._data_kwargs.get('dataset_dir', 'data')
                filter_type = self._model_kwargs.get('filter_type', 'unknown')
                save_path = os.path.join(save_dir, '{}_latent_{}_{}.npz'.format(filter_type, self.experimentID, self._epoch_num))
                np.savez_compressed(save_path, latent=latents)

            if bool(self._model_kwargs.get('nfe', False)):
                # dump per-iteration nfe/time/err statistics for this evaluation run
                res = pd.DataFrame(c, columns=['nfe', 'time', 'err'])
                res.index.name = 'iter'
                filter_type = self._model_kwargs.get('filter_type', 'unknown')
                atol = float(self._model_kwargs.get('odeint_atol', 1e-5))
                rtol = float(self._model_kwargs.get('odeint_rtol', 1e-5))
                nfe_file = os.path.join(
                    self._data_kwargs.get('dataset_dir', 'data'),
                    'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
                res.to_pickle(nfe_file)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..99721e3
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
scipy>=1.5.2
numpy>=1.19.1
pandas>=1.1.5
pyyaml>=5.3.1
pytorch>=1.7.1
future>=0.18.2
torchdiffeq>=0.2.0
\ No newline at end of file
diff --git a/stden_eval.py b/stden_eval.py
new file mode 100644
index 0000000..6032141
--- /dev/null
+++ b/stden_eval.py
@@ -0,0 +1,43 @@
from __future__ import absolute_import
from __future__ import division
from __future__
import print_function + +import argparse +import yaml + +from lib.utils import load_graph_data +from model.stden_supervisor import STDENSupervisor + +import numpy as np +import torch + +def main(args): + with open(args.config_filename) as f: + supervisor_config = yaml.load(f) + + graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename') + adj_mx = load_graph_data(graph_pkl_filename) + + supervisor = STDENSupervisor(adj_mx=adj_mx, **supervisor_config) + + horizon = supervisor_config['model'].get('horizon') + extract_latent = supervisor_config['model'].get('save_latent') + supervisor.eval_more(dataset='test', + save=args.save_pred, + seq_len=np.arange(1, horizon+1, 1), + extract_latent=extract_latent) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config_filename', default=None, type=str, + help='Configuration filename for restoring the model.') + parser.add_argument('--use_cpu_only', default=False, type=bool, help='Set to true to only use cpu.') + parser.add_argument('-r', '--random_seed', type=int, default=2021, help="Random seed for reproduction.") + parser.add_argument('--save_pred', action='store_true', help='Save the prediction.') + args = parser.parse_args() + + torch.manual_seed(args.random_seed) + np.random.seed(args.random_seed) + + main(args) diff --git a/stden_train.py b/stden_train.py new file mode 100644 index 0000000..a7af700 --- /dev/null +++ b/stden_train.py @@ -0,0 +1,37 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import yaml + +from lib.utils import load_graph_data +from model.stden_supervisor import STDENSupervisor + +import numpy as np +import torch + +def main(args): + with open(args.config_filename) as f: + supervisor_config = yaml.load(f) + + graph_pkl_filename = supervisor_config['data'].get('graph_pkl_filename') + adj_mx = load_graph_data(graph_pkl_filename) + + supervisor = 
STDENSupervisor(adj_mx=adj_mx, **supervisor_config) + + supervisor.train() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config_filename', default=None, type=str, + help='Configuration filename for restoring the model.') + parser.add_argument('--use_cpu_only', default=False, type=bool, help='Set to true to only use cpu.') + parser.add_argument('-r', '--random_seed', type=int, default=2021, help="Random seed for reproduction.") + args = parser.parse_args() + + torch.manual_seed(args.random_seed) + np.random.seed(args.random_seed) + + main(args) diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 0000000..021ba80 --- /dev/null +++ b/test.ipynb @@ -0,0 +1,535 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([[ 0, 1, 2, 3, 4],\n", + " [ 0, 6, 7, 8, 9],\n", + " [ 0, 0, 12, 13, 14],\n", + " [ 0, 0, 0, 18, 19],\n", + " [ 0, 0, 0, 0, 24]]),\n", + " array([[ 0, 0, 0, 0, 0],\n", + " [ 5, 6, 0, 0, 0],\n", + " [10, 11, 12, 0, 0],\n", + " [15, 16, 17, 18, 0],\n", + " [20, 21, 22, 23, 24]]))" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "m = np.arange(0, 25).reshape((5, 5))\n", + "\n", + "out = np.triu(m)\n", + "inp = np.tril(m)\n", + "out, inp" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1, 0, 1, 0, 0],\n", + " [1, 0, 0, 1, 1],\n", + " [1, 0, 0, 1, 0],\n", + " [1, 1, 1, 0, 0],\n", + " [1, 0, 1, 1, 1]])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "adj = np.random.randint(0, 2, size=(5, 5))\n", + "adj" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + 
"metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[ 0, 0, 2, 0, 0],\n", + " [ 5, 0, 0, 8, 9],\n", + " [10, 0, 0, 13, 0],\n", + " [15, 16, 17, 0, 0],\n", + " [20, 0, 22, 23, 48]])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "((inp + out) * adj)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1],\n", + " [2],\n", + " [3]])" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a = torch.tensor([1, 2, 3])\n", + "a.unsqueeze_(-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 5, 5])\n" + ] + }, + { + "ename": "IndexError", + "evalue": "Dimension out of range (expected to be in range of [-1, 0], but got 1)", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mIndexError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze_\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 4\u001b[1;33m \u001b[0mr\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mr\u001b[0m \u001b[1;33m>\u001b[0m 
\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstart_dim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend_dim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[1;31mIndexError\u001b[0m: Dimension out of range (expected to be in range of [-1, 0], but got 1)" + ] + } + ], + "source": [ + "r = torch.tensor(((inp + out) * adj))\n", + "r.unsqueeze_(0)\n", + "print(r.shape)\n", + "r[r > 0].flatten(start_dim=0, end_dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "r = r.repeat((2, 1, 1))" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 0, 0, 2, 0, 0, 5, 0, 0, 8, 9, 10, 0, 0, 13, 0, 15, 16, 17,\n", + " 0, 0, 20, 0, 22, 23, 48],\n", + " [ 0, 0, 2, 0, 0, 5, 0, 0, 8, 9, 10, 0, 0, 13, 0, 15, 16, 17,\n", + " 0, 0, 20, 0, 22, 23, 48]], dtype=torch.int32)" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "r = torch.flatten(r, start_dim=1, end_dim=-1)\n", + "r[r>0]" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 2, 5, 8, 9, 10, 13, 15, 16, 17, 20, 22, 23, 48],\n", + " [ 2, 5, 8, 9, 10, 13, 15, 16, 17, 20, 22, 23, 48]],\n", + " dtype=torch.int32)" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "r[r > 0].reshape(2, -1)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn as nn" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [], + "source": [ + "class 
GraphGrad(torch.nn.Module):\n", + " def __init__(self, adj_mx):\n", + " \"\"\"Graph gradient operator that transform functions on nodes to functions on edges.\n", + " \"\"\"\n", + " super(GraphGrad, self).__init__()\n", + " self.adj_mx = adj_mx\n", + " self.grad = self._grad(adj_mx)\n", + " \n", + " @staticmethod\n", + " def _grad(adj_mx):\n", + " \"\"\"Fetch the graph gradient operator.\"\"\"\n", + " num_nodes = adj_mx.size()[-1]\n", + "\n", + " num_edges = (adj_mx > 0.).sum()\n", + " grad = torch.zeros(num_nodes, num_edges)\n", + " e = 0\n", + " for i in range(num_nodes):\n", + " for j in range(num_nodes):\n", + " if adj_mx[i, j] == 0:\n", + " continue\n", + "\n", + " grad[i, e] = 1.\n", + " grad[j, e] = -1.\n", + " e += 1\n", + " return grad\n", + "\n", + " def forward(self, z):\n", + " \"\"\"Transform z with shape `(..., num_nodes)` to f with shape `(..., num_edges)`.\n", + " \"\"\"\n", + " return torch.matmul(z, self.grad)" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1, 0, 1, 0, 0],\n", + " [1, 0, 0, 1, 1],\n", + " [1, 0, 0, 1, 0],\n", + " [1, 1, 1, 0, 0],\n", + " [1, 0, 1, 1, 1]])" + ] + }, + "execution_count": 68, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "adj" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([5, 14])" + ] + }, + "execution_count": 84, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gg = GraphGrad(torch.tensor(adj))\n", + "grad = gg.grad\n", + "grad.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([14, 5])" + ] + }, + "execution_count": 94, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grad.transpose(-1, -2).shape" + ] + }, + { + "cell_type": "code", + 
"execution_count": 97, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "14" + ] + }, + "execution_count": 97, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grad.size(-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(5, 5)" + ] + }, + "execution_count": 73, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inp.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0.],\n", + " [ -5., 5., 1., 6., 6., -5., 0., -5., -6., 0., -5., 0.,\n", + " 0., 0.],\n", + " [-10., -2., 1., 11., 11., 2., 12., -10., -11., -12., -10., -12.,\n", + " 0., 0.],\n", + " [-15., -2., 1., -2., 16., 2., -1., 3., 2., 1., -15., -17.,\n", + " -18., 0.],\n", + " [-20., -2., 1., -2., -3., 2., -1., 3., 2., 1., 4., 2.,\n", + " 1., -24.]])" + ] + }, + "execution_count": 88, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gg(torch.tensor(inp, dtype=torch.float32))" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([5, 5])" + ] + }, + "execution_count": 80, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.tensor(inp).shape \n", + "# (grad_T.T)" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1;31mDocstring:\u001b[0m\n", + "matmul(input, other, *, out=None) -> Tensor\n", + "\n", + "Matrix product of two tensors.\n", + "\n", + "The behavior depends on the dimensionality of the tensors as follows:\n", + "\n", + "- If both tensors are 1-dimensional, the dot product (scalar) is returned.\n", + "- If both arguments are 
2-dimensional, the matrix-matrix product is returned.\n", + "- If the first argument is 1-dimensional and the second argument is 2-dimensional,\n", + " a 1 is prepended to its dimension for the purpose of the matrix multiply.\n", + " After the matrix multiply, the prepended dimension is removed.\n", + "- If the first argument is 2-dimensional and the second argument is 1-dimensional,\n", + " the matrix-vector product is returned.\n", + "- If both arguments are at least 1-dimensional and at least one argument is\n", + " N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first\n", + " argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the\n", + " batched matrix multiply and removed after. If the second argument is 1-dimensional, a\n", + " 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after.\n", + " The non-matrix (i.e. batch) dimensions are :ref:`broadcasted ` (and thus\n", + " must be broadcastable). For example, if :attr:`input` is a\n", + " :math:`(j \\times 1 \\times n \\times n)` tensor and :attr:`other` is a :math:`(k \\times n \\times n)`\n", + " tensor, :attr:`out` will be a :math:`(j \\times k \\times n \\times n)` tensor.\n", + "\n", + " Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs\n", + " are broadcastable, and not the matrix dimensions. For example, if :attr:`input` is a\n", + " :math:`(j \\times 1 \\times n \\times m)` tensor and :attr:`other` is a :math:`(k \\times m \\times p)`\n", + " tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the\n", + " matrix dimensions) are different. :attr:`out` will be a :math:`(j \\times k \\times n \\times p)` tensor.\n", + "\n", + "This operator supports :ref:`TensorFloat32`.\n", + "\n", + ".. 
note::\n", + "\n", + " The 1-dimensional dot product version of this function does not support an :attr:`out` parameter.\n", + "\n", + "Arguments:\n", + " input (Tensor): the first tensor to be multiplied\n", + " other (Tensor): the second tensor to be multiplied\n", + "\n", + "Keyword args:\n", + " out (Tensor, optional): the output tensor.\n", + "\n", + "Example::\n", + "\n", + " >>> # vector x vector\n", + " >>> tensor1 = torch.randn(3)\n", + " >>> tensor2 = torch.randn(3)\n", + " >>> torch.matmul(tensor1, tensor2).size()\n", + " torch.Size([])\n", + " >>> # matrix x vector\n", + " >>> tensor1 = torch.randn(3, 4)\n", + " >>> tensor2 = torch.randn(4)\n", + " >>> torch.matmul(tensor1, tensor2).size()\n", + " torch.Size([3])\n", + " >>> # batched matrix x broadcasted vector\n", + " >>> tensor1 = torch.randn(10, 3, 4)\n", + " >>> tensor2 = torch.randn(4)\n", + " >>> torch.matmul(tensor1, tensor2).size()\n", + " torch.Size([10, 3])\n", + " >>> # batched matrix x batched matrix\n", + " >>> tensor1 = torch.randn(10, 3, 4)\n", + " >>> tensor2 = torch.randn(10, 4, 5)\n", + " >>> torch.matmul(tensor1, tensor2).size()\n", + " torch.Size([10, 3, 5])\n", + " >>> # batched matrix x broadcasted matrix\n", + " >>> tensor1 = torch.randn(10, 3, 4)\n", + " >>> tensor2 = torch.randn(4, 5)\n", + " >>> torch.matmul(tensor1, tensor2).size()\n", + " torch.Size([10, 3, 5])\n", + "\u001b[1;31mType:\u001b[0m builtin_function_or_method\n" + ] + } + ], + "source": [ + "torch.matmul?" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "interpreter": { + "hash": "1b89aa55be347d0b8cc51b3a166e8002614a385bd8cff32165269c80e70c12a7" + }, + "kernelspec": { + "display_name": "Python 3.8.5 64-bit ('base': conda)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}