From d0af46ea5ffa9512b0756dfb5168f0265310d9ff Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 10 Nov 2025 11:13:39 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E7=BC=BA=E5=BA=93=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/DCRNN/dcrnn_model.py | 2 +- model/STFGNN/STFGNN.py | 2 +- model/STGCN/models.py | 2 +- model/STMLP/STMLP.py | 2 +- model/ST_SSL/ST-SSL.py | 5 +- model/ST_SSL/ST_SSL.py | 2 +- requirements.txt | 4 +- utils/get_adj.py | 218 +++++++++++++++++++++++++++++++++++++ 8 files changed, 228 insertions(+), 9 deletions(-) create mode 100644 utils/get_adj.py diff --git a/model/DCRNN/dcrnn_model.py b/model/DCRNN/dcrnn_model.py index 6af81a2..3a91e83 100755 --- a/model/DCRNN/dcrnn_model.py +++ b/model/DCRNN/dcrnn_model.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn from model.DCRNN.dcrnn_cell import DCGRUCell -from data.get_adj import get_adj +from utils.get_adj import get_adj class Seq2SeqAttrs: diff --git a/model/STFGNN/STFGNN.py b/model/STFGNN/STFGNN.py index db6e26f..734bb15 100755 --- a/model/STFGNN/STFGNN.py +++ b/model/STFGNN/STFGNN.py @@ -2,7 +2,7 @@ import torch import torch.nn.functional as F import torch.nn as nn -from data.get_adj import get_adj +from utils.get_adj import get_adj class gcn_operation(nn.Module): diff --git a/model/STGCN/models.py b/model/STGCN/models.py index 7674e8a..ad52e17 100755 --- a/model/STGCN/models.py +++ b/model/STGCN/models.py @@ -1,7 +1,7 @@ import torch.nn as nn from model.STGCN import layers -from data.get_adj import get_gso +from utils.get_adj import get_gso class STGCNChebGraphConv(nn.Module): diff --git a/model/STMLP/STMLP.py b/model/STMLP/STMLP.py index 3eb640b..dd7fc89 100644 --- a/model/STMLP/STMLP.py +++ b/model/STMLP/STMLP.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import init -from data.get_adj import get_adj +from utils.get_adj import get_adj import numbers diff --git a/model/ST_SSL/ST-SSL.py b/model/ST_SSL/ST-SSL.py index 1aaa647..096f7e5 100644 --- a/model/ST_SSL/ST-SSL.py +++ b/model/ST_SSL/ST-SSL.py @@ -1,8 +1,7 @@ -import torch import torch.nn as nn from model.ST-SSL.models import STSSL -from model.ST-SSL.layers import STEncoder, MLP -from data.get_adj import get_gso +from model.ST-SSL.layers +from utils.get_adj import get_gso class STSSLModel(nn.Module): def __init__(self, args): diff --git a/model/ST_SSL/ST_SSL.py b/model/ST_SSL/ST_SSL.py index 4a4660b..1302c0e 100644 --- a/model/ST_SSL/ST_SSL.py +++ b/model/ST_SSL/ST_SSL.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from data.get_adj import get_gso +from utils.get_adj import get_gso class STSSLModel(nn.Module): diff --git a/requirements.txt b/requirements.txt index 6f83511..d23694d 100755 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ torchaudio torchdiffeq fastdtw notebook -torchcde \ No newline at end of file +torchcde +einops +transformers \ No newline at end of file diff --git a/utils/get_adj.py b/utils/get_adj.py new file mode 100644 index 0000000..a74b90e --- /dev/null +++ b/utils/get_adj.py @@ -0,0 +1,218 @@ +import csv +import os +import numpy as np +from scipy.sparse.linalg import norm +import scipy.sparse as sp +import torch + +def get_adj(args): + dataset_path = './data' + match args['num_nodes']: + case 358: + dataset_name = 'PEMS03' + adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv') + id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt') + A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id) + case 307: + dataset_name = 'PEMS04' + adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv') + A, adj = load_adj(args['num_nodes'], adj_path, std=True) + case 883: + dataset_name = 'PEMS07' + adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv') + A, adj = load_adj(args['num_nodes'], adj_path) + case 170: + dataset_name = 'PEMS08' + adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv') + A, adj = load_adj(args['num_nodes'], adj_path, std=True) + + return adj + +def get_gso(args): + dataset_path = './data' + match args['num_nodes']: + case 358: + dataset_name = 'PEMS03' + adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv') + id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt') + A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id) + case 307: + dataset_name = 'PEMS04' + adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv') + A, adj = load_adj(args['num_nodes'], adj_path, std=True) + case 883: + dataset_name = 'PEMS07' + adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv') + A, adj = load_adj(args['num_nodes'], adj_path) + case 170: + dataset_name = 'PEMS08' + adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv') + A, adj = load_adj(args['num_nodes'], adj_path, std=True) + + gso = calc_gso(adj, args['gso_type']) + if args['graph_conv_type'] == 'cheb_graph_conv': + gso = calc_chebynet_gso(gso) + gso = gso.toarray() + gso = gso.astype(dtype=np.float32) + gso = torch.from_numpy(gso).to(args['device']) + return gso + +def load_adj(num_nodes, adj_path, id_filename=None, std=False): + ''' + Parameters + ---------- + adj_path: str, path of the csv file contains edges information + num_nodes: int, the number of vertices + id_filename: str, optional, path of the file containing node IDs (if not starting from 0) + std: bool, if True, normalize the cost values in the CSV file using Gaussian normalization + + Returns + ---------- + A: np.ndarray, adjacency matrix + distanceA: np.ndarray, distance matrix (normalized if std=True) + ''' + if 'npy' in adj_path: + adj_mx = np.load(adj_path) + return adj_mx, None + + else: + A = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32) + distanceA = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32) + + # 如果提供了id_filename,说明节点ID不是从0开始的,需要重新映射 + if id_filename: + with open(id_filename, 'r') as f: + id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))} + + with open(adj_path, 'r') as f: + f.readline() # 略过表头那一行 + reader = csv.reader(f) + costs = [] # 用于收集所有cost值 + for row in reader: + if len(row) != 3: + continue + i, j, distance = int(row[0]), int(row[1]), float(row[2]) + A[id_dict[i], id_dict[j]] = 1 + # 确保距离值为正 + distance = max(distance, 1e-6) + costs.append(distance) # 收集cost值 + distanceA[id_dict[i], id_dict[j]] = distance + + else: # 如果没有提供id_filename,说明节点ID是从0开始的 + with open(adj_path, 'r') as f: + f.readline() # 略过表头那一行 + reader = csv.reader(f) + costs = [] # 用于收集所有cost值 + for row in reader: + if len(row) != 3: + continue + i, j, distance = int(row[0]), int(row[1]), float(row[2]) + A[i, j] = 1 + # 确保距离值为正 + distance = max(distance, 1e-6) + costs.append(distance) # 收集cost值 + distanceA[i, j] = distance + + # 如果std=True,对CSV中的所有cost值进行高斯正态分布标准化 + if std: + mean_cost = np.mean(costs) # 计算cost值的均值 + std_cost = np.std(costs) # 计算cost值的标准差 + for idx in np.ndindex(distanceA.shape): # 遍历矩阵 + if distanceA[idx] > 0: # 只对非零元素进行标准化 + normalized_value = (distanceA[idx] - mean_cost) / std_cost + # 确保标准化后的值为正 + normalized_value = max(normalized_value, 1e-6) + distanceA[idx] = normalized_value + + # 确保矩阵中没有零行 + row_sums = distanceA.sum(axis=1) + zero_rows = np.where(row_sums == 0)[0] + for row in zero_rows: + distanceA[row, :] = 1e-6 # 将零行替换为一个非零的默认值 + + return A, distanceA + + +def calc_gso(dir_adj, gso_type): + n_vertex = dir_adj.shape[0] + + if not sp.issparse(dir_adj): + dir_adj = sp.csc_matrix(dir_adj) + elif dir_adj.format != 'csc': + dir_adj = dir_adj.tocsc() + + id = sp.identity(n_vertex, format='csc') + + # Symmetrizing an adjacency matrix + adj = dir_adj + dir_adj.T.multiply(dir_adj.T > dir_adj) - dir_adj.multiply(dir_adj.T > dir_adj) + # adj = 0.5 * (dir_adj + dir_adj.transpose()) + + if gso_type in ['sym_renorm_adj', 'rw_renorm_adj', 'sym_renorm_lap', 'rw_renorm_lap']: + adj = adj + id + + if gso_type in ['sym_norm_adj', 'sym_renorm_adj', 'sym_norm_lap', 'sym_renorm_lap']: + row_sum = adj.sum(axis=1).A1 + # Check for zero or negative values in row_sum + if np.any(row_sum <= 0): + raise ValueError( + "Row sum contains zero or negative values, which is not allowed for symmetric normalization.") + + row_sum_inv_sqrt = np.power(row_sum, -0.5) + row_sum_inv_sqrt[np.isinf(row_sum_inv_sqrt)] = 0. # Handle inf values + deg_inv_sqrt = sp.diags(row_sum_inv_sqrt, format='csc') + # A_{sym} = D^{-0.5} * A * D^{-0.5} + sym_norm_adj = deg_inv_sqrt.dot(adj).dot(deg_inv_sqrt) + + if gso_type in ['sym_norm_lap', 'sym_renorm_lap']: + sym_norm_lap = id - sym_norm_adj + gso = sym_norm_lap + else: + gso = sym_norm_adj + + elif gso_type in ['rw_norm_adj', 'rw_renorm_adj', 'rw_norm_lap', 'rw_renorm_lap']: + row_sum = np.sum(adj, axis=1).A1 + # Check for zero or negative values in row_sum + if np.any(row_sum <= 0): + raise ValueError( + "Row sum contains zero or negative values, which is not allowed for random walk normalization.") + + row_sum_inv = np.power(row_sum, -1) + row_sum_inv[np.isinf(row_sum_inv)] = 0. # Handle inf values + deg_inv = sp.diags(row_sum_inv, format='csc') + # A_{rw} = D^{-1} * A + rw_norm_adj = deg_inv.dot(adj) + + if gso_type in ['rw_norm_lap', 'rw_renorm_lap']: + rw_norm_lap = id - rw_norm_adj + gso = rw_norm_lap + else: + gso = rw_norm_adj + + else: + raise ValueError(f'{gso_type} is not defined.') + + # Check for nan or inf in the final result + if np.isnan(gso.data).any() or np.isinf(gso.data).any(): + raise ValueError("NaN or Inf detected in the final GSO matrix. Please check the input adjacency matrix.") + + return gso + + +def calc_chebynet_gso(gso): + if sp.issparse(gso) == False: + gso = sp.csc_matrix(gso) + elif gso.format != 'csc': + gso = gso.tocsc() + + id = sp.identity(gso.shape[0], format='csc') + # If you encounter a NotImplementedError, please update your scipy version to 1.10.1 or later. + eigval_max = norm(gso, 2) + + # If the gso is symmetric or random walk normalized Laplacian, + # then the maximum eigenvalue is smaller than or equals to 2. + if eigval_max >= 2: + gso = gso - id + else: + gso = 2 * gso / eigval_max - id + + return gso