TrafficWheel/model/DCRNN/utils.py
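
"""Utility functions for the DCRNN model: graph matrix construction (normalized,
random-walk, and scaled Laplacians), logging setup, and dataset/pickle loading."""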

import logging
import numpy as np
import os
import pickle
import scipy.sparse as sp
import sys
from scipy.sparse import linalg


def calculate_normalized_laplacian(adj):
    """
    Compute the symmetrically normalized Laplacian.

    L = D^-1/2 (D - A) D^-1/2 = I - D^-1/2 A D^-1/2, with D = diag(A 1).
    :param adj: adjacency matrix (dense or sparse).
    :return: normalized Laplacian as a sparse matrix.
    """
    adj = sp.coo_matrix(adj)
    d = np.array(adj.sum(1))
    d_inv_sqrt = np.power(d, -0.5).flatten()
    # Isolated nodes have zero degree; replace the resulting infs with 0.
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    normalized_laplacian = (
        sp.eye(adj.shape[0])
        - adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
    )
    return normalized_laplacian


def calculate_random_walk_matrix(adj_mx):
    """Compute the random-walk (row-normalized) transition matrix D^-1 A."""
    adj_mx = sp.coo_matrix(adj_mx)
    d = np.array(adj_mx.sum(1))
    d_inv = np.power(d, -1).flatten()
    d_inv[np.isinf(d_inv)] = 0.0
    d_mat_inv = sp.diags(d_inv)
    random_walk_mx = d_mat_inv.dot(adj_mx).tocoo()
    return random_walk_mx


def calculate_reverse_random_walk_matrix(adj_mx):
    return calculate_random_walk_matrix(np.transpose(adj_mx))


def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True):
    """Rescale the normalized Laplacian to [-1, 1]: L_tilde = 2 L / lambda_max - I."""
    if undirected:
        adj_mx = np.maximum.reduce([adj_mx, adj_mx.T])
    L = calculate_normalized_laplacian(adj_mx)
    if lambda_max is None:
        # Estimate the largest eigenvalue instead of using the upper bound of 2.
        lambda_max, _ = linalg.eigsh(L, 1, which="LM")
        lambda_max = lambda_max[0]
    L = sp.csr_matrix(L)
    M, _ = L.shape
    I = sp.identity(M, format="csr", dtype=L.dtype)
    L = (2 / lambda_max * L) - I
    return L.astype(np.float32)


def config_logging(log_dir, log_filename="info.log", level=logging.INFO):
    # Add file handler and stdout handler.
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    # Create the log directory if necessary.
    try:
        os.makedirs(log_dir)
    except OSError:
        pass
    file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
    file_handler.setFormatter(formatter)
    file_handler.setLevel(level=level)
    # Add console handler.
    console_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(level=level)
    logging.basicConfig(handlers=[file_handler, console_handler], level=level)


def get_logger(log_dir, name, log_filename="info.log", level=logging.INFO):
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Add file handler and stdout handler.
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
    file_handler.setFormatter(formatter)
    # Add console handler.
    console_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(console_formatter)
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    logger.info("Log directory: %s", log_dir)
    return logger


def get_total_trainable_parameter_size():
    """
    Calculates the total number of trainable parameters in the current graph.

    Relies on the TensorFlow 1.x graph API, so TensorFlow is imported lazily
    and only when this helper is actually called.
    :return: total number of trainable parameters.
    """
    import tensorflow as tf

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is a TensorShape made of tf.Dimension objects.
        total_parameters += np.prod([x.value for x in variable.get_shape()])
    return total_parameters
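

# ---------------------------------------------------------------------------
# NOTE: ``load_dataset`` below references ``StandardScaler`` and ``DataLoader``,
# which are not defined in this excerpt. The classes below are minimal sketches
# following the usual DCRNN conventions (a z-score scaler and a batch iterator
# that pads the last batch with the final sample); the actual TrafficWheel
# implementations may differ.
# ---------------------------------------------------------------------------
class StandardScaler:
    """Standardize data with a fixed mean/std (z-score normalization)."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        return (data * self.std) + self.mean


class DataLoader:
    """Mini-batch iterator over paired (x, y) arrays."""

    def __init__(self, xs, ys, batch_size, shuffle=False, pad_with_last_sample=True):
        self.batch_size = batch_size
        self.current_ind = 0
        if pad_with_last_sample:
            # Pad so the sample count is a multiple of batch_size.
            num_padding = (batch_size - (len(xs) % batch_size)) % batch_size
            xs = np.concatenate([xs, np.repeat(xs[-1:], num_padding, axis=0)], axis=0)
            ys = np.concatenate([ys, np.repeat(ys[-1:], num_padding, axis=0)], axis=0)
        if shuffle:
            permutation = np.random.permutation(len(xs))
            xs, ys = xs[permutation], ys[permutation]
        self.xs = xs
        self.ys = ys
        self.size = len(xs)
        self.num_batch = self.size // self.batch_size

    def get_iterator(self):
        self.current_ind = 0

        def _wrapper():
            while self.current_ind < self.num_batch:
                start_ind = self.batch_size * self.current_ind
                end_ind = min(self.size, start_ind + self.batch_size)
                yield self.xs[start_ind:end_ind], self.ys[start_ind:end_ind]
                self.current_ind += 1

        return _wrapper()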


def load_dataset(dataset_dir, batch_size, test_batch_size=None, **kwargs):
    data = {}
    for category in ["train", "val", "test"]:
        cat_data = np.load(os.path.join(dataset_dir, category + ".npz"))
        data["x_" + category] = cat_data["x"]
        data["y_" + category] = cat_data["y"]
    scaler = StandardScaler(
        mean=data["x_train"][..., 0].mean(), std=data["x_train"][..., 0].std()
    )
    # Normalize the first feature channel with the training-set statistics.
    for category in ["train", "val", "test"]:
        data["x_" + category][..., 0] = scaler.transform(data["x_" + category][..., 0])
        data["y_" + category][..., 0] = scaler.transform(data["y_" + category][..., 0])
    data["train_loader"] = DataLoader(
        data["x_train"], data["y_train"], batch_size, shuffle=True
    )
    data["val_loader"] = DataLoader(
        data["x_val"], data["y_val"], test_batch_size, shuffle=False
    )
    data["test_loader"] = DataLoader(
        data["x_test"], data["y_test"], test_batch_size, shuffle=False
    )
    data["scaler"] = scaler
    return data


def load_graph_data(pkl_filename):
    sensor_ids, sensor_id_to_ind, adj_mx = load_pickle(pkl_filename)
    return sensor_ids, sensor_id_to_ind, adj_mx


def load_pickle(pickle_file):
    try:
        with open(pickle_file, "rb") as f:
            pickle_data = pickle.load(f)
    except UnicodeDecodeError:
        # Fall back to latin1 for pickles written with Python 2.
        with open(pickle_file, "rb") as f:
            pickle_data = pickle.load(f, encoding="latin1")
    except Exception as e:
        print("Unable to load data ", pickle_file, ":", e)
        raise
    return pickle_data
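

if __name__ == "__main__":
    # Small illustrative check (not part of the original module): build the
    # graph matrices for a 3-node chain graph and report their shapes.
    A = np.array(
        [
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 1.0],
            [0.0, 1.0, 0.0],
        ]
    )
    print("Normalized Laplacian:", calculate_normalized_laplacian(A).shape)
    print("Random-walk matrix:  ", calculate_random_walk_matrix(A).shape)
    print("Scaled Laplacian:    ", calculate_scaled_laplacian(A).shape)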