# STDEN/lib/utils.py
#
# 228 lines
# 7.8 KiB
# Python

import logging
import numpy as np
import os
import time
import scipy.sparse as sp
import sys
import torch
import torch.nn as nn
class DataLoader(object):
    """Minimal in-memory batch iterator over paired sample arrays.

    ``xs`` and ``ys`` are sliced with NumPy-style indexing along their first
    axis (presumably ``np.ndarray`` inputs -- confirm against callers).
    """

    def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
        """
        :param xs: input samples, indexed along the first axis.
        :param ys: targets aligned with ``xs`` along the first axis.
        :param batch_size: number of samples per batch.
        :param pad_with_last_sample: pad with the last sample to make number of samples divisible to batch_size.
        :param shuffle: if True, apply one random permutation (shared by
            ``xs`` and ``ys``) at construction time; it is not redrawn on
            later calls to :meth:`get_iterator`.
        """
        self.batch_size = batch_size
        self.current_ind = 0
        if pad_with_last_sample:
            num_padding = (batch_size - (len(xs) % batch_size)) % batch_size
            x_padding = np.repeat(xs[-1:], num_padding, axis=0)
            y_padding = np.repeat(ys[-1:], num_padding, axis=0)
            xs = np.concatenate([xs, x_padding], axis=0)
            ys = np.concatenate([ys, y_padding], axis=0)
        self.size = len(xs)
        # Ceiling division: when padding is disabled and the sample count is
        # not a multiple of batch_size, the trailing partial batch must still
        # be served (get_iterator already clamps end_ind for it). With padding
        # enabled the size is an exact multiple, so the result is unchanged.
        self.num_batch = int(-(-self.size // self.batch_size))
        if shuffle:
            permutation = np.random.permutation(self.size)
            xs, ys = xs[permutation], ys[permutation]
        self.xs = xs
        self.ys = ys

    def get_iterator(self):
        """Return a fresh generator yielding ``(x_batch, y_batch)`` tuples.

        The cursor ``self.current_ind`` is reset on every call and is stored
        on the instance, so two generators obtained from the same loader
        share (and advance) the same position.
        """
        self.current_ind = 0

        def _wrapper():
            while self.current_ind < self.num_batch:
                start_ind = self.batch_size * self.current_ind
                # Clamp so the last (possibly shorter) batch stops at the data end.
                end_ind = min(self.size, self.batch_size * (self.current_ind + 1))
                x_i = self.xs[start_ind: end_ind, ...]
                y_i = self.ys[start_ind: end_ind, ...]
                yield (x_i, y_i)
                self.current_ind += 1

        return _wrapper()
class StandardScaler:
    """
    Affine normaliser defined by a fixed mean and standard deviation.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def transform(self, data):
        """Return ``data`` shifted by ``mean`` and then divided by ``std``."""
        centered = data - self.mean
        return centered / self.std

    def inverse_transform(self, data):
        """Undo :meth:`transform`: multiply by ``std`` and then add ``mean``."""
        rescaled = data * self.std
        return rescaled + self.mean
def calculate_random_walk_matrix(adj_mx):
    """Return the row-normalised matrix ``D^-1 A`` in SciPy COO format.

    ``D`` is the diagonal matrix of row sums of ``adj_mx``. Rows whose sum is
    zero get an inverse degree of 0, so they stay all-zero in the result.

    :param adj_mx: adjacency matrix; anything ``scipy.sparse.coo_matrix``
        accepts (dense array or sparse matrix).
    :return: ``scipy.sparse.coo_matrix`` holding ``D^-1 A``.
    """
    adj_mx = sp.coo_matrix(adj_mx)
    d = np.array(adj_mx.sum(1)).flatten()
    # NumPy refuses negative integer powers of integer arrays, so promote
    # integer (or boolean) degrees to float first. Float inputs keep their
    # dtype so the result precision is unchanged.
    if not np.issubdtype(d.dtype, np.inexact):
        d = d.astype(float)
    # A zero row sum yields inf here; that is expected and handled on the
    # next line, so silence the divide-by-zero warning.
    with np.errstate(divide='ignore'):
        d_inv = np.power(d, -1)
    d_inv[np.isinf(d_inv)] = 0.
    d_mat_inv = sp.diags(d_inv)
    random_walk_mx = d_mat_inv.dot(adj_mx).tocoo()
    return random_walk_mx
def config_logging(log_dir, log_filename='info.log', level=logging.INFO):
    """Configure root logging to write to a file and to stdout.

    Builds a file handler for ``log_dir/log_filename`` and a stdout handler,
    both at ``level``, and hands them to ``logging.basicConfig``. Note that
    ``basicConfig`` leaves the root logger untouched if it already has
    handlers.

    :param log_dir: directory for the log file; created if missing.
    :param log_filename: name of the log file inside ``log_dir``.
    :param level: logging level applied to both handlers and the root logger.
    """
    # Create the log directory if necessary. exist_ok tolerates an existing
    # directory while still surfacing real failures (e.g. permissions) here
    # instead of silently swallowing them. An empty log_dir means "current
    # directory" for os.path.join below, so there is nothing to create.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # File handler: includes the logger name in each record.
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
    file_handler.setFormatter(formatter)
    file_handler.setLevel(level=level)

    # Console handler: shorter format, written to stdout.
    console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(level=level)

    logging.basicConfig(handlers=[file_handler, console_handler], level=level)
def get_logger(log_dir, name, log_filename='info.log', level=logging.INFO):
    """Return the logger ``name`` with a file handler and a stdout handler.

    Calling this repeatedly with the same arguments is safe: a handler is
    only attached when the logger does not already have an equivalent one,
    so records are not emitted multiple times.

    :param log_dir: directory for the log file; created if missing.
    :param name: name passed to ``logging.getLogger``.
    :param log_filename: name of the log file inside ``log_dir``.
    :param level: level set on the logger.
    :return: the configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Mirror config_logging: make sure the target directory exists before
    # FileHandler tries to open the file. Empty log_dir means the current
    # directory, which needs no creation.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    # FileHandler stores its path as an absolute path in ``baseFilename``.
    log_path = os.path.abspath(os.path.join(log_dir, log_filename))

    # Add file handler unless this logger already writes to the same file.
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_path
        for h in logger.handlers
    )
    if not has_file_handler:
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    # Add console handler unless this logger already writes to stdout.
    # FileHandler subclasses StreamHandler, so exclude it explicitly.
    has_console_handler = any(
        isinstance(h, logging.StreamHandler)
        and not isinstance(h, logging.FileHandler)
        and getattr(h, 'stream', None) is sys.stdout
        for h in logger.handlers
    )
    if not has_console_handler:
        console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(console_formatter)
        logger.addHandler(console_handler)

    logger.info('Log directory: %s', log_dir)
    return logger
def get_log_dir(kwargs):
    """Return the run's log directory, creating it if necessary.

    If ``kwargs['train']['log_dir']`` is set it is used as-is. Otherwise a
    run id is built from the hyper-parameters in ``kwargs['data']`` and
    ``kwargs['model']`` plus a ``%m%d%H%M%S`` timestamp, and placed under
    ``kwargs['log_base_dir']``.

    :param kwargs: configuration dict with ``'train'`` and, when no explicit
        ``log_dir`` is given, ``'data'``, ``'model'`` and ``'log_base_dir'``.
    :return: path of the (now existing) log directory.
    """
    log_dir = kwargs['train'].get('log_dir')
    if log_dir is None:
        batch_size = kwargs['data'].get('batch_size')

        model_cfg = kwargs['model']
        filter_type = model_cfg.get('filter_type')
        gcn_step = model_cfg.get('gcn_step')
        horizon = model_cfg.get('horizon')
        latent_dim = model_cfg.get('latent_dim')
        n_traj_samples = model_cfg.get('n_traj_samples')
        ode_method = model_cfg.get('ode_method')
        seq_len = model_cfg.get('seq_len')
        rnn_units = model_cfg.get('rnn_units')
        recg_type = model_cfg.get('recg_type')

        # Short tag for the filter type; anything unrecognised maps to 'DF'.
        if filter_type == 'unkP':
            filter_type_abbr = 'UP'
        elif filter_type == 'IncP':
            filter_type_abbr = 'NV'
        else:
            filter_type_abbr = 'DF'

        run_id = 'STDEN_%s-%d_%s-%d_L-%d_N-%d_M-%s_bs-%d_%d-%d_%s/' % (
            recg_type, rnn_units, filter_type_abbr, gcn_step, latent_dim,
            n_traj_samples, ode_method, batch_size, seq_len, horizon,
            time.strftime('%m%d%H%M%S'))
        base_dir = kwargs.get('log_base_dir')
        log_dir = os.path.join(base_dir, run_id)
    # exist_ok avoids the check-then-create race of testing os.path.exists
    # first (another process may create the directory in between).
    os.makedirs(log_dir, exist_ok=True)
    return log_dir
def load_dataset(dataset_dir, batch_size, val_batch_size=None, **kwargs):
    """Load train/val/test splits, standardise them and build batch loaders.

    If ``'BJ'`` occurs in ``dataset_dir`` all arrays are read from a single
    ``flow.npz``; otherwise each split is read from ``<split>.npz`` using its
    ``'x'`` and ``'y'`` entries. Inputs and targets of every split are
    standardised with the scalar mean and std of ``x_train``.

    :param dataset_dir: directory containing the ``.npz`` file(s).
    :param batch_size: batch size of the (shuffled) training loader.
    :param val_batch_size: batch size of the validation and test loaders;
        defaults to ``batch_size`` when omitted.
    :param kwargs: accepted for call-site compatibility; unused here.
    :return: dict with ``x_*``/``y_*`` arrays, ``train_loader``,
        ``val_loader``, ``test_loader`` and ``scaler``.
    """
    # DataLoader does arithmetic on the batch size, so the None default must
    # be resolved before it is passed on.
    if val_batch_size is None:
        val_batch_size = batch_size

    categories = ('train', 'val', 'test')
    if 'BJ' in dataset_dir:
        # Copy the read-only NpzFile into a writable dict, then close it.
        # NOTE(review): the original carried a disabled ``[..., :4]`` slice
        # ("ignore the time index") on the x arrays -- confirm whether the
        # time index should be dropped for this dataset.
        with np.load(os.path.join(dataset_dir, 'flow.npz')) as archive:
            data = dict(archive)
    else:
        data = {}
        for category in categories:
            with np.load(os.path.join(dataset_dir, category + '.npz')) as cat_data:
                data['x_' + category] = cat_data['x']
                data['y_' + category] = cat_data['y']

    scaler = StandardScaler(mean=data['x_train'].mean(), std=data['x_train'].std())
    # Data format: the same scaler is applied to inputs and targets.
    for category in categories:
        data['x_' + category] = scaler.transform(data['x_' + category])
        data['y_' + category] = scaler.transform(data['y_' + category])

    data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True)
    data['val_loader'] = DataLoader(data['x_val'], data['y_val'], val_batch_size, shuffle=False)
    data['test_loader'] = DataLoader(data['x_test'], data['y_test'], val_batch_size, shuffle=False)
    data['scaler'] = scaler
    return data
def load_graph_data(pkl_filename):
    """Read the adjacency matrix stored at ``pkl_filename`` via ``np.load``.

    Despite the parameter name, the file is read with NumPy's loader rather
    than the ``pickle`` module.
    """
    return np.load(pkl_filename)
def graph_grad(adj_mx):
    """Fetch the graph gradient operator.

    Every non-zero entry ``adj_mx[i, j]`` is treated as an edge ``e`` from
    node ``i`` to node ``j``. Edges are numbered in row-major order, and
    column ``e`` of the result holds ``+1`` at row ``i`` and ``-1`` at row
    ``j``. For a self-loop the ``-1`` is written last and wins.

    :param adj_mx: square adjacency matrix supporting ``.shape``,
        ``adj_mx[i, j]`` and elementwise comparison.
    :return: ``torch.Tensor`` of shape ``(num_nodes, num_edges)``.
    """
    num_nodes = adj_mx.shape[0]
    # Count with the same predicate the loop below uses (non-zero), so that
    # negative weights cannot make the edge index run past the allocated
    # number of columns.
    num_edges = int((adj_mx != 0).sum())
    grad = torch.zeros(num_nodes, num_edges)
    e = 0
    for i in range(num_nodes):
        for j in range(num_nodes):
            if adj_mx[i, j] == 0:
                continue
            grad[i, e] = 1.
            grad[j, e] = -1.
            e += 1
    return grad
def init_network_weights(net, std = 0.1):
    """
    Re-initialise every ``nn.Linear`` layer found in ``net.modules()``.

    Weights are drawn from a normal distribution with mean 0 and the given
    ``std``; biases are set to 0. Other module types are left untouched.

    :param net: module whose sub-modules are scanned.
    :param std: standard deviation of the weight initialisation.
    """
    for m in net.modules():
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0, std=std)
            # Layers built with bias=False have ``bias is None``.
            if m.bias is not None:
                nn.init.constant_(m.bias, val=0)
def split_last_dim(data):
    """Cut ``data`` in two along its final axis at index ``size // 2``.

    Returns a ``(first, second)`` tuple; when the final axis has odd length
    the second part is one element longer.
    """
    midpoint = data.size()[-1] // 2
    first_part = data[..., :midpoint]
    second_part = data[..., midpoint:]
    return first_part, second_part
def get_device(tensor):
    """Return where ``tensor`` lives.

    For a CUDA tensor this is the value of ``tensor.get_device()``;
    otherwise it is ``torch.device("cpu")``.
    """
    if not tensor.is_cuda:
        return torch.device("cpu")
    return tensor.get_device()
def sample_standard_gaussian(mu, sigma):
    """Draw ``mu + sigma * eps`` with ``eps`` sampled from a unit normal.

    The noise is drawn with the shape of ``mu`` on the device reported by
    ``get_device(mu)``; ``mu`` and ``sigma`` are cast to float before the
    result is formed.
    """
    target = get_device(mu)
    zero = torch.Tensor([0.]).to(target)
    one = torch.Tensor([1.]).to(target)
    unit_normal = torch.distributions.normal.Normal(zero, one)
    # The distribution has a trailing length-1 batch axis; drop it so the
    # noise has the same shape as mu.
    noise = unit_normal.sample(mu.size()).squeeze(-1)
    return noise * sigma.float() + mu.float()
def create_net(n_inputs, n_outputs, n_layers = 0,
        n_units = 100, nonlinear = nn.Tanh):
    """Build a fully connected ``nn.Sequential`` network.

    Layout: ``Linear(n_inputs, n_units)``, then ``n_layers`` repetitions of
    ``nonlinear()`` + ``Linear(n_units, n_units)``, then ``nonlinear()`` +
    ``Linear(n_units, n_outputs)``.
    """
    # Modules are instantiated strictly front to back.
    modules = [nn.Linear(n_inputs, n_units)]
    for _ in range(n_layers):
        modules += [nonlinear(), nn.Linear(n_units, n_units)]
    modules += [nonlinear(), nn.Linear(n_units, n_outputs)]
    return nn.Sequential(*modules)