Project-I/data/get_adj.py

221 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import csv
import os
import numpy as np
import pandas as pd
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import norm
import scipy.sparse as sp
import torch
# Supported PEMS datasets, keyed by node count:
# num_nodes -> (dataset name, whether <name>.txt is passed to load_adj as the
#               node-ID file, std flag passed to load_adj)
_PEMS_DATASETS = {
    358: ('PEMS03', True, False),
    307: ('PEMS04', False, True),
    883: ('PEMS07', False, False),
    170: ('PEMS08', False, True),
}


def _load_dataset_adj(args, dataset_path='./data'):
    '''Return load_adj's cost matrix for the dataset selected by args['num_nodes'].'''
    num_nodes = args['num_nodes']
    try:
        dataset_name, has_id_file, std = _PEMS_DATASETS[num_nodes]
    except KeyError:
        raise ValueError(
            f"Unsupported num_nodes={num_nodes!r}; expected one of {sorted(_PEMS_DATASETS)}."
        ) from None
    dataset_dir = os.path.join(dataset_path, dataset_name)
    adj_path = os.path.join(dataset_dir, f'{dataset_name}.csv')
    id_filename = os.path.join(dataset_dir, f'{dataset_name}.txt') if has_id_file else None
    _, adj = load_adj(num_nodes, adj_path, id_filename=id_filename, std=std)
    return adj


def get_adj(args, dataset_path='./data'):
    '''
    Load the edge-cost matrix of the PEMS dataset selected by node count.

    Parameters
    ----------
    args: dict, must contain 'num_nodes'
        (358 -> PEMS03, 307 -> PEMS04, 883 -> PEMS07, 170 -> PEMS08)
    dataset_path: str, root directory holding one sub-directory per dataset

    Returns
    ----------
    adj: np.ndarray, the ``distanceA`` matrix returned by load_adj

    Raises
    ----------
    ValueError: if args['num_nodes'] does not match a supported dataset
    '''
    return _load_dataset_adj(args, dataset_path)


def get_gso(args, dataset_path='./data'):
    '''
    Build the graph shift operator for the selected dataset as a dense float32 tensor.

    Parameters
    ----------
    args: dict, keys read:
        'num_nodes'       -- selects the dataset (see get_adj)
        'gso_type'        -- forwarded to calc_gso
        'graph_conv_type' -- 'cheb_graph_conv' additionally applies calc_chebynet_gso
        'device'          -- target passed to ``Tensor.to``
    dataset_path: str, root directory holding one sub-directory per dataset

    Returns
    ----------
    gso: torch.Tensor, float32 dense operator on args['device']

    Raises
    ----------
    ValueError: unsupported num_nodes, or any error raised by calc_gso
    '''
    adj = _load_dataset_adj(args, dataset_path)
    gso = calc_gso(adj, args['gso_type'])
    if args['graph_conv_type'] == 'cheb_graph_conv':
        gso = calc_chebynet_gso(gso)
    # calc_gso / calc_chebynet_gso return scipy sparse matrices; densify for torch.
    gso = gso.toarray().astype(dtype=np.float32)
    return torch.from_numpy(gso).to(args['device'])
def load_adj(num_nodes, adj_path, id_filename=None, std=False):
    '''
    Load a directed graph from an edge-list CSV (or a pre-built ``.npy`` matrix).

    Parameters
    ----------
    num_nodes: int, the number of vertices
    adj_path: str, path of the edge file. A path ending in ``.npy`` is loaded
        with ``np.load`` and returned as-is. Any other path is read as a CSV
        whose first line is a header, followed by ``from,to,cost`` rows; rows
        that do not have exactly 3 fields are skipped.
    id_filename: str, optional, path of a text file with one node ID per line.
        If given, CSV node IDs are remapped to their 0-based line position in
        this file; otherwise CSV IDs are used directly as 0-based indices.
    std: bool, if True, replace each stored cost by its z-score
        ``(cost - mean) / std`` over all CSV costs, clamped to at least 1e-6
        (so costs at or below the mean collapse to 1e-6)

    Returns
    ----------
    A: np.ndarray, float32 (num_nodes, num_nodes) matrix with A[i, j] = 1 for
        every CSV edge i -> j (for ``.npy`` input: the loaded array)
    distanceA: np.ndarray, float32 (num_nodes, num_nodes) cost matrix
        (normalized if std=True); rows without any outgoing edge are filled
        with 1e-6 (for ``.npy`` input: None)

    Raises
    ----------
    KeyError: a CSV node ID is not listed in the ID file
    ValueError: a CSV field cannot be parsed as int/float
    IndexError: a (mapped) node index is >= num_nodes
    '''
    if adj_path.endswith('.npy'):
        # Pre-built matrix: no separate cost matrix is available.
        return np.load(adj_path), None

    n = int(num_nodes)
    A = np.zeros((n, n), dtype=np.float32)
    distanceA = np.zeros((n, n), dtype=np.float32)

    # CSV node IDs may not start at 0; if an ID file is given, map each raw
    # ID to its line position in that file.
    id_dict = None
    if id_filename:
        with open(id_filename, 'r') as f:
            id_dict = {int(node_id): idx
                       for idx, node_id in enumerate(f.read().strip().split('\n'))}

    costs = []  # every (clamped) cost read from the CSV, for the optional z-score
    with open(adj_path, 'r') as f:
        f.readline()  # skip the header line
        for row in csv.reader(f):
            if len(row) != 3:
                continue
            i, j, distance = int(row[0]), int(row[1]), float(row[2])
            if id_dict is not None:
                i, j = id_dict[i], id_dict[j]
            A[i, j] = 1
            # Keep costs strictly positive so "> 0" identifies stored edges below.
            distance = max(distance, 1e-6)
            costs.append(distance)
            distanceA[i, j] = distance

    if std and costs:
        mean_cost = np.mean(costs)
        std_cost = np.std(costs)
        edge_mask = distanceA > 0
        # Compute in float64, then store back into the float32 matrix.
        values = distanceA[edge_mask].astype(np.float64)
        if std_cost > 0:
            z_scores = (values - mean_cost) / std_cost
        else:
            # All costs identical: every z-score is 0 (avoids 0/0 -> NaN).
            z_scores = np.zeros_like(values)
        # Clamp so normalized edges stay strictly positive.
        distanceA[edge_mask] = np.maximum(z_scores, 1e-6)

    # Fill all-zero rows (vertices with no outgoing edge) with 1e-6 so the
    # matrix has no zero rows -- presumably to keep downstream degree
    # normalization (see calc_gso's row-sum check) well defined.
    distanceA[distanceA.sum(axis=1) == 0, :] = 1e-6
    return A, distanceA
def calc_gso(dir_adj, gso_type):
    '''
    Build a graph shift operator (GSO) from a (possibly directed) adjacency matrix.

    The adjacency is first symmetrized as the element-wise maximum of A and A^T.
    ``*_renorm_*`` variants add self-loops (A + I) before normalization.

    Parameters
    ----------
    dir_adj: dense array or scipy sparse matrix, square weight matrix
    gso_type: str, one of
        'sym_norm_adj', 'sym_renorm_adj'  -> D^{-1/2} A D^{-1/2}
        'sym_norm_lap', 'sym_renorm_lap'  -> I - D^{-1/2} A D^{-1/2}
        'rw_norm_adj',  'rw_renorm_adj'   -> D^{-1} A
        'rw_norm_lap',  'rw_renorm_lap'   -> I - D^{-1} A
        where D is the diagonal matrix of row sums of the symmetrized A.

    Returns
    ----------
    gso: scipy sparse matrix (CSC), the requested operator

    Raises
    ----------
    ValueError: unknown gso_type, a zero/negative row sum, or NaN/Inf in the result
    '''
    sym_types = ('sym_norm_adj', 'sym_renorm_adj', 'sym_norm_lap', 'sym_renorm_lap')
    rw_types = ('rw_norm_adj', 'rw_renorm_adj', 'rw_norm_lap', 'rw_renorm_lap')
    renorm_types = ('sym_renorm_adj', 'rw_renorm_adj', 'sym_renorm_lap', 'rw_renorm_lap')
    lap_types = ('sym_norm_lap', 'sym_renorm_lap', 'rw_norm_lap', 'rw_renorm_lap')
    if gso_type not in sym_types and gso_type not in rw_types:
        raise ValueError(f'{gso_type} is not defined.')

    n_vertex = dir_adj.shape[0]
    if not sp.issparse(dir_adj):
        dir_adj = sp.csc_matrix(dir_adj)
    elif dir_adj.format != 'csc':
        dir_adj = dir_adj.tocsc()
    identity = sp.identity(n_vertex, format='csc')

    # Symmetrize with an exact element-wise max(A, A^T).
    adj = dir_adj.maximum(dir_adj.T)
    if gso_type in renorm_types:
        adj = adj + identity

    # np.asarray(...).ravel() works whether .sum returns np.matrix or ndarray.
    row_sum = np.asarray(adj.sum(axis=1)).ravel()
    if np.any(row_sum <= 0):
        kind = 'symmetric' if gso_type in sym_types else 'random walk'
        raise ValueError(
            f"Row sum contains zero or negative values, which is not allowed for {kind} normalization.")

    if gso_type in sym_types:
        deg_scale = np.power(row_sum, -0.5)
    else:
        # 1.0 / x (unlike np.power(x, -1)) also accepts integer row sums.
        deg_scale = 1.0 / row_sum
    deg_scale[np.isinf(deg_scale)] = 0.  # guard against overflow on tiny degrees
    deg = sp.diags(deg_scale, format='csc')

    if gso_type in sym_types:
        # A_{sym} = D^{-0.5} * A * D^{-0.5}
        norm_adj = deg @ adj @ deg
    else:
        # A_{rw} = D^{-1} * A
        norm_adj = deg @ adj

    gso = identity - norm_adj if gso_type in lap_types else norm_adj

    # Check for NaN or Inf in the final result.
    if np.isnan(gso.data).any() or np.isinf(gso.data).any():
        raise ValueError("NaN or Inf detected in the final GSO matrix. Please check the input adjacency matrix.")
    return gso
def calc_chebynet_gso(gso):
    '''
    Rescale a GSO for Chebyshev graph convolution: 2 * gso / lambda_max - I.

    lambda_max is estimated with the spectral 2-norm of ``gso`` (its largest
    singular value, equal to the largest |eigenvalue| when gso is symmetric).
    If that estimate is >= 2 the operator is only shifted (gso - I).

    Parameters
    ----------
    gso: dense array or scipy sparse matrix, square graph shift operator

    Returns
    ----------
    scipy sparse matrix (CSC), the rescaled operator
    '''
    if not sp.issparse(gso):
        gso = sp.csc_matrix(gso)
    elif gso.format != 'csc':
        gso = gso.tocsc()
    identity = sp.identity(gso.shape[0], format='csc')

    if gso.count_nonzero() == 0:
        # A zero operator stays zero under any scaling, so the result is -I;
        # this also avoids dividing by a zero norm (NaN).
        return gso - identity

    # If you encounter a NotImplementedError, please update your scipy version to 1.10.1 or later.
    eigval_max = norm(gso, 2)
    # For the symmetric or random-walk normalized Laplacian the largest
    # eigenvalue is at most 2.
    if eigval_max >= 2:
        return gso - identity
    return 2 * gso / eigval_max - identity