# DCRNN/lib/utils.py
#
# File-viewer metadata carried over from the original listing:
# 213 lines, 7.1 KiB, Python

import logging
import numpy as np
import os
import pickle
import scipy.sparse as sp
import sys
import tensorflow as tf
from scipy.sparse import linalg
class DataLoader(object):
    """
    Serve aligned ``(x, y)`` samples as consecutive mini-batches of ``batch_size`` samples.
    """

    def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
        """
        :param xs: input samples, indexed along axis 0 (numpy-style indexing is used on them).
        :param ys: target samples, aligned with ``xs`` along axis 0.
        :param batch_size: number of samples per batch.
        :param pad_with_last_sample: pad with the last sample to make number of samples divisible to batch_size.
            When False, a trailing incomplete batch is never served because ``num_batch`` is rounded down.
        :param shuffle: apply one random permutation (shared by ``xs`` and ``ys``) at construction time,
            after padding; the order is not reshuffled for later iterators.
        """
        self.batch_size = batch_size
        # Index of the batch most recently served by any iterator (kept for backward compatibility).
        self.current_ind = 0
        if pad_with_last_sample:
            # The outer modulo makes the padding zero when the size is already divisible.
            num_padding = (batch_size - (len(xs) % batch_size)) % batch_size
            x_padding = np.repeat(xs[-1:], num_padding, axis=0)
            y_padding = np.repeat(ys[-1:], num_padding, axis=0)
            xs = np.concatenate([xs, x_padding], axis=0)
            ys = np.concatenate([ys, y_padding], axis=0)
        self.size = len(xs)
        self.num_batch = int(self.size // self.batch_size)
        if shuffle:
            permutation = np.random.permutation(self.size)
            xs, ys = xs[permutation], ys[permutation]
        self.xs = xs
        self.ys = ys

    def get_iterator(self):
        """
        Return a generator yielding ``(x_batch, y_batch)`` tuples for batches ``0 .. num_batch - 1``.

        Every generator tracks its own position, so several iterators over the same loader can be
        consumed side by side. ``self.current_ind`` is still updated as batches are served, but it
        only mirrors the iterator that advanced last.
        :return: generator of ``(x_batch, y_batch)`` tuples.
        """
        self.current_ind = 0

        def _wrapper():
            for batch_ind in range(self.num_batch):
                self.current_ind = batch_ind
                start_ind = self.batch_size * batch_ind
                # num_batch is floored, so this never runs past self.size; no clamping needed.
                end_ind = start_ind + self.batch_size
                x_i = self.xs[start_ind: end_ind, ...]
                y_i = self.ys[start_ind: end_ind, ...]
                yield (x_i, y_i)
            self.current_ind = self.num_batch

        return _wrapper()
class StandardScaler:
    """
    Standardize data with a fixed mean and standard deviation, and undo that scaling.
    """

    def __init__(self, mean, std):
        """
        :param mean: value subtracted from the data by ``transform``.
        :param std: value the centered data is divided by in ``transform``.
        """
        self.mean, self.std = mean, std

    def transform(self, data):
        """
        :param data: values to standardize.
        :return: ``(data - mean) / std``.
        """
        centered = data - self.mean
        return centered / self.std

    def inverse_transform(self, data):
        """
        :param data: standardized values.
        :return: ``data * std + mean``.
        """
        rescaled = data * self.std
        return rescaled + self.mean
def add_simple_summary(writer, names, values, global_step):
    """
    Write one scalar summary per ``(name, value)`` pair.
    :param writer: object exposing ``add_summary(summary, global_step)``.
    :param names: summary tags, paired positionally with ``values``.
    :param values: scalar values, one per tag.
    :param global_step: step passed along with every summary.
    :return: None
    """
    for tag, scalar in zip(names, values):
        # Each scalar gets its own Summary message holding a single value entry.
        summary = tf.Summary()
        entry = summary.value.add()
        entry.tag = tag
        entry.simple_value = scalar
        writer.add_summary(summary, global_step)
def calculate_normalized_laplacian(adj):
    """
    Build the symmetrically normalized graph Laplacian.
    # L = D^-1/2 (D-A) D^-1/2 = I - D^-1/2 A D^-1/2
    # D = diag(A 1)
    :param adj: adjacency matrix, in any form accepted by ``sp.coo_matrix``.
    :return: the normalized Laplacian as a scipy sparse matrix.
    """
    adjacency = sp.coo_matrix(adj)
    degree = np.array(adjacency.sum(1))
    inv_sqrt_degree = np.power(degree, -0.5).flatten()
    # Rows with zero degree produce inf; zero them so those nodes contribute nothing.
    inv_sqrt_degree[np.isinf(inv_sqrt_degree)] = 0.
    scaling = sp.diags(inv_sqrt_degree)
    # (A S)^T S == S A^T S, which equals S A S whenever A is symmetric.
    normalized_adjacency = adjacency.dot(scaling).transpose().dot(scaling).tocoo()
    identity = sp.eye(adjacency.shape[0])
    return identity - normalized_adjacency
def calculate_random_walk_matrix(adj_mx):
    """
    Compute the random-walk transition matrix ``D^-1 A`` with ``D = diag(A 1)``.
    :param adj_mx: adjacency matrix, in any form accepted by ``sp.coo_matrix``.
        Integer and boolean matrices are supported.
    :return: ``D^-1 A`` as a sparse COO matrix; rows of zero-degree nodes are all zero.
    """
    adj_mx = sp.coo_matrix(adj_mx)
    d = np.array(adj_mx.sum(1)).flatten()
    if not np.issubdtype(d.dtype, np.inexact):
        # np.power refuses negative integer powers of integer arrays, so promote first.
        d = d.astype(np.float64)
    # Zero-degree rows divide by zero here; they are reset to 0 right below,
    # so the warning would only be noise.
    with np.errstate(divide='ignore'):
        d_inv = np.power(d, -1)
    d_inv[np.isinf(d_inv)] = 0.
    d_mat_inv = sp.diags(d_inv)
    random_walk_mx = d_mat_inv.dot(adj_mx).tocoo()
    return random_walk_mx
def calculate_reverse_random_walk_matrix(adj_mx):
    """
    Compute the random-walk matrix of the transposed adjacency matrix.
    :param adj_mx: adjacency matrix.
    :return: result of ``calculate_random_walk_matrix`` applied to ``adj_mx`` transposed.
    """
    reversed_adj = np.transpose(adj_mx)
    return calculate_random_walk_matrix(reversed_adj)
def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True):
    """
    Rescale the normalized Laplacian as ``(2 / lambda_max) * L - I``.
    :param adj_mx: adjacency matrix; must expose ``.T`` when ``undirected`` is True.
    :param lambda_max: scaling constant; when None, the largest-magnitude eigenvalue
        of ``L`` is computed with ``eigsh`` and used instead.
    :param undirected: when True, symmetrize ``adj_mx`` by taking the element-wise
        maximum of the matrix and its transpose before building ``L``.
    :return: the rescaled Laplacian as a float32 sparse matrix.
    """
    if undirected:
        adj_mx = np.maximum.reduce([adj_mx, adj_mx.T])
    laplacian = calculate_normalized_laplacian(adj_mx)
    if lambda_max is None:
        eigenvalues, _ = linalg.eigsh(laplacian, 1, which='LM')
        lambda_max = eigenvalues[0]
    laplacian = sp.csr_matrix(laplacian)
    num_nodes = laplacian.shape[0]
    identity = sp.identity(num_nodes, format='csr', dtype=laplacian.dtype)
    scaled = (2 / lambda_max * laplacian) - identity
    return scaled.astype(np.float32)
def config_logging(log_dir, log_filename='info.log', level=logging.INFO):
    """
    Configure the root logger with a file handler and a stdout handler.

    ``logging.basicConfig`` does nothing when the root logger already has handlers,
    so only the first effective call installs the handlers.
    :param log_dir: directory for the log file; created if missing. An empty value
        skips directory creation and the log file path is just ``log_filename``.
    :param log_filename: name of the log file inside ``log_dir``.
    :param level: level applied to both handlers and to the root logger.
    :return: None
    :raises OSError: if the directory cannot be created or the log file cannot be opened.
    """
    # Create the log directory if necessary. exist_ok tolerates an existing directory
    # while still surfacing real failures (e.g. permissions) at their source.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    # Add file handler.
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
    file_handler.setFormatter(formatter)
    file_handler.setLevel(level=level)
    # Add console handler.
    console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(level=level)
    logging.basicConfig(handlers=[file_handler, console_handler], level=level)
def get_logger(log_dir, name, log_filename='info.log', level=logging.INFO):
    """
    Return the logger called ``name`` with a file handler and a stdout handler attached.

    Handlers are only attached when an equivalent one is not already present, so
    calling this repeatedly for the same logger does not duplicate log lines.
    :param log_dir: directory of the log file; it is not created here and must exist.
    :param name: logger name passed to ``logging.getLogger``.
    :param log_filename: name of the log file inside ``log_dir``.
    :param level: level set on the logger.
    :return: the configured logger.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Add file handler (FileHandler stores its target as an absolute path in baseFilename).
    log_path = os.path.abspath(os.path.join(log_dir, log_filename))
    has_file_handler = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not has_file_handler:
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    # Add console handler. FileHandler subclasses StreamHandler, hence the exclusion.
    has_console_handler = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        and getattr(handler, 'stream', None) is sys.stdout
        for handler in logger.handlers
    )
    if not has_console_handler:
        console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(console_formatter)
        logger.addHandler(console_handler)
    logger.info('Log directory: %s', log_dir)
    return logger
def get_total_trainable_parameter_size():
    """
    Calculates the total number of trainable parameters in the current graph.
    :return: sum over ``tf.trainable_variables()`` of the product of each variable's dimensions.
    """
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        # np.prod replaces the np.product alias, which was removed in NumPy 2.0.
        total_parameters += np.prod([x.value for x in variable.get_shape()])
    return total_parameters
def load_dataset(dataset_dir, batch_size, test_batch_size=None, **kwargs):
    """
    Load the train/val/test splits, standardize feature 0 and wrap each split in a DataLoader.

    Each split is read from ``<dataset_dir>/<split>.npz`` using its ``x`` and ``y`` arrays.
    The scaler's mean and std are computed from index 0 of the last axis of ``x_train``
    and applied in place to that same slice of every ``x`` and ``y`` array.
    :param dataset_dir: directory containing ``train.npz``, ``val.npz`` and ``test.npz``.
    :param batch_size: batch size of the (shuffled) training loader.
    :param test_batch_size: batch size of the validation and test loaders;
        falls back to ``batch_size`` when None.
    :param kwargs: accepted for call-site compatibility and ignored.
    :return: dict with ``x_*``/``y_*`` arrays, ``train_loader``, ``val_loader``,
        ``test_loader`` and ``scaler``.
    """
    if test_batch_size is None:
        # The declared default would otherwise fail inside DataLoader's arithmetic.
        test_batch_size = batch_size
    categories = ('train', 'val', 'test')
    data = {}
    for category in categories:
        # NpzFile keeps the archive open; the context manager closes it once the
        # arrays have been read into memory.
        with np.load(os.path.join(dataset_dir, category + '.npz')) as cat_data:
            data['x_' + category] = cat_data['x']
            data['y_' + category] = cat_data['y']
    scaler = StandardScaler(mean=data['x_train'][..., 0].mean(), std=data['x_train'][..., 0].std())
    # Data format
    for category in categories:
        data['x_' + category][..., 0] = scaler.transform(data['x_' + category][..., 0])
        data['y_' + category][..., 0] = scaler.transform(data['y_' + category][..., 0])
    data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True)
    data['val_loader'] = DataLoader(data['x_val'], data['y_val'], test_batch_size, shuffle=False)
    data['test_loader'] = DataLoader(data['x_test'], data['y_test'], test_batch_size, shuffle=False)
    data['scaler'] = scaler
    return data
def load_graph_data(pkl_filename):
    """
    Load the three-item graph description stored in a pickle file.
    :param pkl_filename: path of the pickle file.
    :return: tuple ``(sensor_ids, sensor_id_to_ind, adj_mx)`` as stored in the file.
    """
    # Unpacking enforces that the pickle holds exactly three items.
    ids, id_to_ind, adjacency = load_pickle(pkl_filename)
    return ids, id_to_ind, adjacency
def load_pickle(pickle_file):
    """
    Load a pickle file, retrying with latin1 decoding if the default decoding fails.
    :param pickle_file: path of the pickle file.
    :return: the unpickled object.
    """
    # NOTE: unpickling can execute arbitrary code; only load files from trusted sources.
    try:
        with open(pickle_file, 'rb') as handle:
            return pickle.load(handle)
    except UnicodeDecodeError:
        # The default string decoding failed; retry decoding byte strings as latin1.
        with open(pickle_file, 'rb') as handle:
            return pickle.load(handle, encoding='latin1')
    except Exception as err:
        print('Unable to load data ', pickle_file, ':', err)
        raise