修复缺库问题
This commit is contained in:
parent
f3480fccdc
commit
d0af46ea5f
|
|
@ -3,7 +3,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from model.DCRNN.dcrnn_cell import DCGRUCell
|
from model.DCRNN.dcrnn_cell import DCGRUCell
|
||||||
from data.get_adj import get_adj
|
from utils.get_adj import get_adj
|
||||||
|
|
||||||
|
|
||||||
class Seq2SeqAttrs:
|
class Seq2SeqAttrs:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from data.get_adj import get_adj
|
from utils.get_adj import get_adj
|
||||||
|
|
||||||
|
|
||||||
class gcn_operation(nn.Module):
|
class gcn_operation(nn.Module):
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from model.STGCN import layers
|
from model.STGCN import layers
|
||||||
from data.get_adj import get_gso
|
from utils.get_adj import get_gso
|
||||||
|
|
||||||
|
|
||||||
class STGCNChebGraphConv(nn.Module):
|
class STGCNChebGraphConv(nn.Module):
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn import init
|
from torch.nn import init
|
||||||
from data.get_adj import get_adj
|
from utils.get_adj import get_adj
|
||||||
import numbers
|
import numbers
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from model.ST-SSL.models import STSSL
|
from model.ST-SSL.models import STSSL
|
||||||
from model.ST-SSL.layers import STEncoder, MLP
|
from model.ST-SSL.layers
|
||||||
from data.get_adj import get_gso
|
from utils.get_adj import get_gso
|
||||||
|
|
||||||
class STSSLModel(nn.Module):
|
class STSSLModel(nn.Module):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from data.get_adj import get_gso
|
from utils.get_adj import get_gso
|
||||||
|
|
||||||
|
|
||||||
class STSSLModel(nn.Module):
|
class STSSLModel(nn.Module):
|
||||||
|
|
|
||||||
|
|
@ -10,3 +10,5 @@ torchdiffeq
|
||||||
fastdtw
|
fastdtw
|
||||||
notebook
|
notebook
|
||||||
torchcde
|
torchcde
|
||||||
|
einops
|
||||||
|
transformers
|
||||||
|
|
@ -0,0 +1,218 @@
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from scipy.sparse.linalg import norm
|
||||||
|
import scipy.sparse as sp
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def get_adj(args):
|
||||||
|
dataset_path = './data'
|
||||||
|
match args['num_nodes']:
|
||||||
|
case 358:
|
||||||
|
dataset_name = 'PEMS03'
|
||||||
|
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
|
||||||
|
id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
|
||||||
|
A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id)
|
||||||
|
case 307:
|
||||||
|
dataset_name = 'PEMS04'
|
||||||
|
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
|
||||||
|
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
|
||||||
|
case 883:
|
||||||
|
dataset_name = 'PEMS07'
|
||||||
|
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
|
||||||
|
A, adj = load_adj(args['num_nodes'], adj_path)
|
||||||
|
case 170:
|
||||||
|
dataset_name = 'PEMS08'
|
||||||
|
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
|
||||||
|
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
|
||||||
|
|
||||||
|
return adj
|
||||||
|
|
||||||
|
def get_gso(args):
|
||||||
|
dataset_path = './data'
|
||||||
|
match args['num_nodes']:
|
||||||
|
case 358:
|
||||||
|
dataset_name = 'PEMS03'
|
||||||
|
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
|
||||||
|
id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
|
||||||
|
A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id)
|
||||||
|
case 307:
|
||||||
|
dataset_name = 'PEMS04'
|
||||||
|
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
|
||||||
|
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
|
||||||
|
case 883:
|
||||||
|
dataset_name = 'PEMS07'
|
||||||
|
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
|
||||||
|
A, adj = load_adj(args['num_nodes'], adj_path)
|
||||||
|
case 170:
|
||||||
|
dataset_name = 'PEMS08'
|
||||||
|
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
|
||||||
|
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
|
||||||
|
|
||||||
|
gso = calc_gso(adj, args['gso_type'])
|
||||||
|
if args['graph_conv_type'] == 'cheb_graph_conv':
|
||||||
|
gso = calc_chebynet_gso(gso)
|
||||||
|
gso = gso.toarray()
|
||||||
|
gso = gso.astype(dtype=np.float32)
|
||||||
|
gso = torch.from_numpy(gso).to(args['device'])
|
||||||
|
return gso
|
||||||
|
|
||||||
|
def load_adj(num_nodes, adj_path, id_filename=None, std=False):
|
||||||
|
'''
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
adj_path: str, path of the csv file contains edges information
|
||||||
|
num_nodes: int, the number of vertices
|
||||||
|
id_filename: str, optional, path of the file containing node IDs (if not starting from 0)
|
||||||
|
std: bool, if True, normalize the cost values in the CSV file using Gaussian normalization
|
||||||
|
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
A: np.ndarray, adjacency matrix
|
||||||
|
distanceA: np.ndarray, distance matrix (normalized if std=True)
|
||||||
|
'''
|
||||||
|
if 'npy' in adj_path:
|
||||||
|
adj_mx = np.load(adj_path)
|
||||||
|
return adj_mx, None
|
||||||
|
|
||||||
|
else:
|
||||||
|
A = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32)
|
||||||
|
distanceA = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32)
|
||||||
|
|
||||||
|
# 如果提供了id_filename,说明节点ID不是从0开始的,需要重新映射
|
||||||
|
if id_filename:
|
||||||
|
with open(id_filename, 'r') as f:
|
||||||
|
id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}
|
||||||
|
|
||||||
|
with open(adj_path, 'r') as f:
|
||||||
|
f.readline() # 略过表头那一行
|
||||||
|
reader = csv.reader(f)
|
||||||
|
costs = [] # 用于收集所有cost值
|
||||||
|
for row in reader:
|
||||||
|
if len(row) != 3:
|
||||||
|
continue
|
||||||
|
i, j, distance = int(row[0]), int(row[1]), float(row[2])
|
||||||
|
A[id_dict[i], id_dict[j]] = 1
|
||||||
|
# 确保距离值为正
|
||||||
|
distance = max(distance, 1e-6)
|
||||||
|
costs.append(distance) # 收集cost值
|
||||||
|
distanceA[id_dict[i], id_dict[j]] = distance
|
||||||
|
|
||||||
|
else: # 如果没有提供id_filename,说明节点ID是从0开始的
|
||||||
|
with open(adj_path, 'r') as f:
|
||||||
|
f.readline() # 略过表头那一行
|
||||||
|
reader = csv.reader(f)
|
||||||
|
costs = [] # 用于收集所有cost值
|
||||||
|
for row in reader:
|
||||||
|
if len(row) != 3:
|
||||||
|
continue
|
||||||
|
i, j, distance = int(row[0]), int(row[1]), float(row[2])
|
||||||
|
A[i, j] = 1
|
||||||
|
# 确保距离值为正
|
||||||
|
distance = max(distance, 1e-6)
|
||||||
|
costs.append(distance) # 收集cost值
|
||||||
|
distanceA[i, j] = distance
|
||||||
|
|
||||||
|
# 如果std=True,对CSV中的所有cost值进行高斯正态分布标准化
|
||||||
|
if std:
|
||||||
|
mean_cost = np.mean(costs) # 计算cost值的均值
|
||||||
|
std_cost = np.std(costs) # 计算cost值的标准差
|
||||||
|
for idx in np.ndindex(distanceA.shape): # 遍历矩阵
|
||||||
|
if distanceA[idx] > 0: # 只对非零元素进行标准化
|
||||||
|
normalized_value = (distanceA[idx] - mean_cost) / std_cost
|
||||||
|
# 确保标准化后的值为正
|
||||||
|
normalized_value = max(normalized_value, 1e-6)
|
||||||
|
distanceA[idx] = normalized_value
|
||||||
|
|
||||||
|
# 确保矩阵中没有零行
|
||||||
|
row_sums = distanceA.sum(axis=1)
|
||||||
|
zero_rows = np.where(row_sums == 0)[0]
|
||||||
|
for row in zero_rows:
|
||||||
|
distanceA[row, :] = 1e-6 # 将零行替换为一个非零的默认值
|
||||||
|
|
||||||
|
return A, distanceA
|
||||||
|
|
||||||
|
|
||||||
|
def calc_gso(dir_adj, gso_type):
|
||||||
|
n_vertex = dir_adj.shape[0]
|
||||||
|
|
||||||
|
if not sp.issparse(dir_adj):
|
||||||
|
dir_adj = sp.csc_matrix(dir_adj)
|
||||||
|
elif dir_adj.format != 'csc':
|
||||||
|
dir_adj = dir_adj.tocsc()
|
||||||
|
|
||||||
|
id = sp.identity(n_vertex, format='csc')
|
||||||
|
|
||||||
|
# Symmetrizing an adjacency matrix
|
||||||
|
adj = dir_adj + dir_adj.T.multiply(dir_adj.T > dir_adj) - dir_adj.multiply(dir_adj.T > dir_adj)
|
||||||
|
# adj = 0.5 * (dir_adj + dir_adj.transpose())
|
||||||
|
|
||||||
|
if gso_type in ['sym_renorm_adj', 'rw_renorm_adj', 'sym_renorm_lap', 'rw_renorm_lap']:
|
||||||
|
adj = adj + id
|
||||||
|
|
||||||
|
if gso_type in ['sym_norm_adj', 'sym_renorm_adj', 'sym_norm_lap', 'sym_renorm_lap']:
|
||||||
|
row_sum = adj.sum(axis=1).A1
|
||||||
|
# Check for zero or negative values in row_sum
|
||||||
|
if np.any(row_sum <= 0):
|
||||||
|
raise ValueError(
|
||||||
|
"Row sum contains zero or negative values, which is not allowed for symmetric normalization.")
|
||||||
|
|
||||||
|
row_sum_inv_sqrt = np.power(row_sum, -0.5)
|
||||||
|
row_sum_inv_sqrt[np.isinf(row_sum_inv_sqrt)] = 0. # Handle inf values
|
||||||
|
deg_inv_sqrt = sp.diags(row_sum_inv_sqrt, format='csc')
|
||||||
|
# A_{sym} = D^{-0.5} * A * D^{-0.5}
|
||||||
|
sym_norm_adj = deg_inv_sqrt.dot(adj).dot(deg_inv_sqrt)
|
||||||
|
|
||||||
|
if gso_type in ['sym_norm_lap', 'sym_renorm_lap']:
|
||||||
|
sym_norm_lap = id - sym_norm_adj
|
||||||
|
gso = sym_norm_lap
|
||||||
|
else:
|
||||||
|
gso = sym_norm_adj
|
||||||
|
|
||||||
|
elif gso_type in ['rw_norm_adj', 'rw_renorm_adj', 'rw_norm_lap', 'rw_renorm_lap']:
|
||||||
|
row_sum = np.sum(adj, axis=1).A1
|
||||||
|
# Check for zero or negative values in row_sum
|
||||||
|
if np.any(row_sum <= 0):
|
||||||
|
raise ValueError(
|
||||||
|
"Row sum contains zero or negative values, which is not allowed for random walk normalization.")
|
||||||
|
|
||||||
|
row_sum_inv = np.power(row_sum, -1)
|
||||||
|
row_sum_inv[np.isinf(row_sum_inv)] = 0. # Handle inf values
|
||||||
|
deg_inv = sp.diags(row_sum_inv, format='csc')
|
||||||
|
# A_{rw} = D^{-1} * A
|
||||||
|
rw_norm_adj = deg_inv.dot(adj)
|
||||||
|
|
||||||
|
if gso_type in ['rw_norm_lap', 'rw_renorm_lap']:
|
||||||
|
rw_norm_lap = id - rw_norm_adj
|
||||||
|
gso = rw_norm_lap
|
||||||
|
else:
|
||||||
|
gso = rw_norm_adj
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f'{gso_type} is not defined.')
|
||||||
|
|
||||||
|
# Check for nan or inf in the final result
|
||||||
|
if np.isnan(gso.data).any() or np.isinf(gso.data).any():
|
||||||
|
raise ValueError("NaN or Inf detected in the final GSO matrix. Please check the input adjacency matrix.")
|
||||||
|
|
||||||
|
return gso
|
||||||
|
|
||||||
|
|
||||||
|
def calc_chebynet_gso(gso):
|
||||||
|
if sp.issparse(gso) == False:
|
||||||
|
gso = sp.csc_matrix(gso)
|
||||||
|
elif gso.format != 'csc':
|
||||||
|
gso = gso.tocsc()
|
||||||
|
|
||||||
|
id = sp.identity(gso.shape[0], format='csc')
|
||||||
|
# If you encounter a NotImplementedError, please update your scipy version to 1.10.1 or later.
|
||||||
|
eigval_max = norm(gso, 2)
|
||||||
|
|
||||||
|
# If the gso is symmetric or random walk normalized Laplacian,
|
||||||
|
# then the maximum eigenvalue is smaller than or equals to 2.
|
||||||
|
if eigval_max >= 2:
|
||||||
|
gso = gso - id
|
||||||
|
else:
|
||||||
|
gso = 2 * gso / eigval_max - id
|
||||||
|
|
||||||
|
return gso
|
||||||
Loading…
Reference in New Issue