修复缺库问题

This commit is contained in:
czzhangheng 2025-11-10 11:13:39 +08:00
parent f3480fccdc
commit d0af46ea5f
8 changed files with 228 additions and 9 deletions

View File

@ -3,7 +3,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from model.DCRNN.dcrnn_cell import DCGRUCell from model.DCRNN.dcrnn_cell import DCGRUCell
from data.get_adj import get_adj from utils.get_adj import get_adj
class Seq2SeqAttrs: class Seq2SeqAttrs:

View File

@ -2,7 +2,7 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.nn as nn import torch.nn as nn
from data.get_adj import get_adj from utils.get_adj import get_adj
class gcn_operation(nn.Module): class gcn_operation(nn.Module):

View File

@ -1,7 +1,7 @@
import torch.nn as nn import torch.nn as nn
from model.STGCN import layers from model.STGCN import layers
from data.get_adj import get_gso from utils.get_adj import get_gso
class STGCNChebGraphConv(nn.Module): class STGCNChebGraphConv(nn.Module):

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn import init from torch.nn import init
from data.get_adj import get_adj from utils.get_adj import get_adj
import numbers import numbers

View File

@ -1,8 +1,7 @@
import torch
import torch.nn as nn import torch.nn as nn
from model.ST-SSL.models import STSSL from model.ST-SSL.models import STSSL
from model.ST-SSL.layers import STEncoder, MLP from model.ST-SSL.layers
from data.get_adj import get_gso from utils.get_adj import get_gso
class STSSLModel(nn.Module): class STSSLModel(nn.Module):
def __init__(self, args): def __init__(self, args):

View File

@ -1,7 +1,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from data.get_adj import get_gso from utils.get_adj import get_gso
class STSSLModel(nn.Module): class STSSLModel(nn.Module):

View File

@ -10,3 +10,5 @@ torchdiffeq
fastdtw fastdtw
notebook notebook
torchcde torchcde
einops
transformers

218
utils/get_adj.py Normal file
View File

@ -0,0 +1,218 @@
import csv
import os
import numpy as np
from scipy.sparse.linalg import norm
import scipy.sparse as sp
import torch
def get_adj(args):
dataset_path = './data'
match args['num_nodes']:
case 358:
dataset_name = 'PEMS03'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id)
case 307:
dataset_name = 'PEMS04'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
case 883:
dataset_name = 'PEMS07'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
A, adj = load_adj(args['num_nodes'], adj_path)
case 170:
dataset_name = 'PEMS08'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
return adj
def get_gso(args):
dataset_path = './data'
match args['num_nodes']:
case 358:
dataset_name = 'PEMS03'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id)
case 307:
dataset_name = 'PEMS04'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
case 883:
dataset_name = 'PEMS07'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
A, adj = load_adj(args['num_nodes'], adj_path)
case 170:
dataset_name = 'PEMS08'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
gso = calc_gso(adj, args['gso_type'])
if args['graph_conv_type'] == 'cheb_graph_conv':
gso = calc_chebynet_gso(gso)
gso = gso.toarray()
gso = gso.astype(dtype=np.float32)
gso = torch.from_numpy(gso).to(args['device'])
return gso
def load_adj(num_nodes, adj_path, id_filename=None, std=False):
'''
Parameters
----------
adj_path: str, path of the csv file contains edges information
num_nodes: int, the number of vertices
id_filename: str, optional, path of the file containing node IDs (if not starting from 0)
std: bool, if True, normalize the cost values in the CSV file using Gaussian normalization
Returns
----------
A: np.ndarray, adjacency matrix
distanceA: np.ndarray, distance matrix (normalized if std=True)
'''
if 'npy' in adj_path:
adj_mx = np.load(adj_path)
return adj_mx, None
else:
A = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32)
distanceA = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32)
# 如果提供了id_filename说明节点ID不是从0开始的需要重新映射
if id_filename:
with open(id_filename, 'r') as f:
id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}
with open(adj_path, 'r') as f:
f.readline() # 略过表头那一行
reader = csv.reader(f)
costs = [] # 用于收集所有cost值
for row in reader:
if len(row) != 3:
continue
i, j, distance = int(row[0]), int(row[1]), float(row[2])
A[id_dict[i], id_dict[j]] = 1
# 确保距离值为正
distance = max(distance, 1e-6)
costs.append(distance) # 收集cost值
distanceA[id_dict[i], id_dict[j]] = distance
else: # 如果没有提供id_filename说明节点ID是从0开始的
with open(adj_path, 'r') as f:
f.readline() # 略过表头那一行
reader = csv.reader(f)
costs = [] # 用于收集所有cost值
for row in reader:
if len(row) != 3:
continue
i, j, distance = int(row[0]), int(row[1]), float(row[2])
A[i, j] = 1
# 确保距离值为正
distance = max(distance, 1e-6)
costs.append(distance) # 收集cost值
distanceA[i, j] = distance
# 如果std=True对CSV中的所有cost值进行高斯正态分布标准化
if std:
mean_cost = np.mean(costs) # 计算cost值的均值
std_cost = np.std(costs) # 计算cost值的标准差
for idx in np.ndindex(distanceA.shape): # 遍历矩阵
if distanceA[idx] > 0: # 只对非零元素进行标准化
normalized_value = (distanceA[idx] - mean_cost) / std_cost
# 确保标准化后的值为正
normalized_value = max(normalized_value, 1e-6)
distanceA[idx] = normalized_value
# 确保矩阵中没有零行
row_sums = distanceA.sum(axis=1)
zero_rows = np.where(row_sums == 0)[0]
for row in zero_rows:
distanceA[row, :] = 1e-6 # 将零行替换为一个非零的默认值
return A, distanceA
def calc_gso(dir_adj, gso_type):
n_vertex = dir_adj.shape[0]
if not sp.issparse(dir_adj):
dir_adj = sp.csc_matrix(dir_adj)
elif dir_adj.format != 'csc':
dir_adj = dir_adj.tocsc()
id = sp.identity(n_vertex, format='csc')
# Symmetrizing an adjacency matrix
adj = dir_adj + dir_adj.T.multiply(dir_adj.T > dir_adj) - dir_adj.multiply(dir_adj.T > dir_adj)
# adj = 0.5 * (dir_adj + dir_adj.transpose())
if gso_type in ['sym_renorm_adj', 'rw_renorm_adj', 'sym_renorm_lap', 'rw_renorm_lap']:
adj = adj + id
if gso_type in ['sym_norm_adj', 'sym_renorm_adj', 'sym_norm_lap', 'sym_renorm_lap']:
row_sum = adj.sum(axis=1).A1
# Check for zero or negative values in row_sum
if np.any(row_sum <= 0):
raise ValueError(
"Row sum contains zero or negative values, which is not allowed for symmetric normalization.")
row_sum_inv_sqrt = np.power(row_sum, -0.5)
row_sum_inv_sqrt[np.isinf(row_sum_inv_sqrt)] = 0. # Handle inf values
deg_inv_sqrt = sp.diags(row_sum_inv_sqrt, format='csc')
# A_{sym} = D^{-0.5} * A * D^{-0.5}
sym_norm_adj = deg_inv_sqrt.dot(adj).dot(deg_inv_sqrt)
if gso_type in ['sym_norm_lap', 'sym_renorm_lap']:
sym_norm_lap = id - sym_norm_adj
gso = sym_norm_lap
else:
gso = sym_norm_adj
elif gso_type in ['rw_norm_adj', 'rw_renorm_adj', 'rw_norm_lap', 'rw_renorm_lap']:
row_sum = np.sum(adj, axis=1).A1
# Check for zero or negative values in row_sum
if np.any(row_sum <= 0):
raise ValueError(
"Row sum contains zero or negative values, which is not allowed for random walk normalization.")
row_sum_inv = np.power(row_sum, -1)
row_sum_inv[np.isinf(row_sum_inv)] = 0. # Handle inf values
deg_inv = sp.diags(row_sum_inv, format='csc')
# A_{rw} = D^{-1} * A
rw_norm_adj = deg_inv.dot(adj)
if gso_type in ['rw_norm_lap', 'rw_renorm_lap']:
rw_norm_lap = id - rw_norm_adj
gso = rw_norm_lap
else:
gso = rw_norm_adj
else:
raise ValueError(f'{gso_type} is not defined.')
# Check for nan or inf in the final result
if np.isnan(gso.data).any() or np.isinf(gso.data).any():
raise ValueError("NaN or Inf detected in the final GSO matrix. Please check the input adjacency matrix.")
return gso
def calc_chebynet_gso(gso):
if sp.issparse(gso) == False:
gso = sp.csc_matrix(gso)
elif gso.format != 'csc':
gso = gso.tocsc()
id = sp.identity(gso.shape[0], format='csc')
# If you encounter a NotImplementedError, please update your scipy version to 1.10.1 or later.
eigval_max = norm(gso, 2)
# If the gso is symmetric or random walk normalized Laplacian,
# then the maximum eigenvalue is smaller than or equals to 2.
if eigval_max >= 2:
gso = gso - id
else:
gso = 2 * gso / eigval_max - id
return gso