TrafficWheel/model/STGODE/new_adj.py

116 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import pandas as pd
import numpy as np
from fastdtw import fastdtw
from tqdm import tqdm
import torch
from joblib import Parallel, delayed
# Maps num_nodes -> (traffic-signal .npz, sensor-distance .csv), joined onto
# './data/' by get_A_hat.
# NOTE(review): only the 358-node entry prefixes its csv with the dataset
# directory ('PeMS03/PEMS03.csv'); the other three are bare filenames placed
# directly under ./data/ — confirm this matches the actual on-disk layout.
files = {
358: ('PeMS03/PEMS03.npz', 'PeMS03/PEMS03.csv'),
307: ('PeMS04/PEMS04.npz', 'PEMS04.csv'),
883: ('PeMS07/PEMS07.npz', 'PEMS07.csv'),
170: ('PeMS08/PEMS08.npz', 'PEMS08.csv')
}
def compute_dtw_pair(i, j, data_mean, radius=6):
    """Compute the DTW distance between the series of nodes *i* and *j*.

    Designed as a joblib worker: the indices are echoed back so the caller
    can place each result into the full distance matrix regardless of the
    order in which parallel workers finish.

    Args:
        i, j: node indices into ``data_mean``.
        data_mean: 2-D array-like of per-node series (one row per node).
        radius: fastdtw search radius; larger is more accurate but slower.
            The default of 6 preserves the previously hard-coded value.

    Returns:
        Tuple ``(i, j, distance)`` where ``distance`` is the scalar DTW cost.
    """
    return i, j, fastdtw(data_mean[i], data_mean[j], radius=radius)[0]
def get_A_hat(args):
    """Build the semantic (DTW-based) and spatial adjacency matrices for STGODE.

    Args:
        args: dict with keys
            'device'    - torch device string (e.g. 'cuda:0' or 'cpu');
            'num_nodes' - number of sensors; selects the dataset via ``files``;
            'sigma1', 'thres1' - Gaussian width / binarization cutoff for the
                DTW (semantic) matrix;
            'sigma2', 'thres2' - Gaussian width / binarization cutoff for the
                spatial matrix.

    Returns:
        Tuple ``(semantic_A_hat, spatial_A_hat)`` of (N, N) float tensors on
        ``args['device']``, each normalized as 0.8 * (I + 0.8 * D^-1/2 A D^-1/2).

    Both raw distance matrices are cached as .npy files under ./data/<dataset>/.
    """
    device = torch.device(args['device'])
    data_dir = './data/'
    num_node = args['num_nodes']
    file_npz, file_csv = files[num_node]
    dataset_name = file_npz.split('/')[0]
    os.makedirs(f'{data_dir}{dataset_name}', exist_ok=True)

    # Load raw data and z-score normalize per feature channel.
    with np.load(f'{data_dir}{file_npz}') as data:
        arr_data = data['data']
    arr_data = (arr_data - arr_data.mean((0, 1))) / arr_data.std((0, 1))
    arr_data = torch.from_numpy(arr_data).float().to(device)

    # ---- DTW (semantic) distance matrix, cached on disk ----
    dtw_path = f'{data_dir}{dataset_name}/{dataset_name}_dtw_distance.npy'
    if not os.path.exists(dtw_path):
        # Average each node's daily profile: unfold time into whole days
        # (288 five-minute slots per day) and mean over the day axis.
        # Shape is (num_node, 288): one average daily series per node.
        # BUG FIX: the previous version appended `.T`, producing (288, num_node)
        # — that compared per-timestep slices across nodes instead of per-node
        # series, and raised IndexError whenever num_node > 288 (PEMS07: 883).
        daily_data = arr_data[..., 0].unfold(0, 288, 288).mean(dim=0).cpu().numpy()
        # Parallel DTW over the upper triangle only (DTW is symmetric).
        print("Computing DTW matrix with parallel optimization...")
        results = Parallel(n_jobs=-1)(
            delayed(compute_dtw_pair)(i, j, daily_data)
            for i in tqdm(range(num_node)) for j in range(i, num_node)
        )
        dtw_matrix = np.full((num_node, num_node), np.inf)
        for i, j, d in results:  # mirror each result into both triangles
            dtw_matrix[i, j] = d
            dtw_matrix[j, i] = d
        np.save(dtw_path, dtw_matrix)
    else:
        dtw_matrix = np.load(dtw_path)

    # Standardize DTW distances, pass through a Gaussian kernel, binarize.
    dtw_tensor = torch.from_numpy(dtw_matrix).to(device)
    dtw_normalized = (dtw_tensor - dtw_tensor.mean()) / dtw_tensor.std()
    semantic_adj = torch.exp(-dtw_normalized ** 2 / args['sigma1'] ** 2)
    semantic_adj = (semantic_adj > args['thres1']).float()

    # ---- Spatial distance matrix, cached on disk ----
    spatial_path = f'{data_dir}{dataset_name}/{dataset_name}_spatial_distance.npy'
    if not os.path.exists(spatial_path):
        # NOTE(review): header=None assumes the csv has no header row; the
        # published PeMS distance files typically begin with a "from,to,cost"
        # line — confirm against the actual data files.
        df = pd.read_csv(f'{data_dir}{file_csv}', header=None)
        if num_node == 358:
            # PEMS03 edges use raw sensor IDs; remap them to 0..N-1 indices
            # using the ordering given by the dataset's node-list file.
            with open(f'{data_dir}{dataset_name}/{dataset_name}.txt') as f:
                node_ids = [int(line.strip()) for line in f]
            id_map = {nid: idx for idx, nid in enumerate(node_ids)}
            df[0] = df[0].map(id_map)
            df[1] = df[1].map(id_map)
        # Dense symmetric distance matrix; inf marks "no direct edge".
        spatial_adj = torch.full((num_node, num_node), float('inf'), device=device)
        for row in df.itertuples():
            i, j, d = int(row[1]), int(row[2]), float(row[3])
            spatial_adj[i, j] = spatial_adj[j, i] = d
        spatial_adj = spatial_adj.cpu().numpy()
        np.save(spatial_path, spatial_adj)
    else:
        spatial_adj = np.load(spatial_path)

    # Standardize using only finite (connected) entries, then Gaussian kernel
    # + threshold. inf entries map to exp(-inf) = 0, so missing edges stay
    # disconnected before thresholding.
    valid_values = spatial_adj[np.isfinite(spatial_adj)]
    spatial_normalized = (spatial_adj - valid_values.mean()) / valid_values.std()
    spatial_tensor = torch.from_numpy(spatial_normalized).float().to(device)
    spatial_adj = torch.exp(-spatial_tensor ** 2 / args['sigma2'] ** 2)
    spatial_adj = (spatial_adj > args['thres2']).float()

    def normalize_adj(adj):
        # A_hat = 0.8 * (I + 0.8 * D^-1/2 A D^-1/2); degrees are clamped to
        # avoid division by zero for isolated nodes.
        D = adj.sum(1)
        D = torch.clamp(D, min=1e-5)
        D_inv_sqrt = 1.0 / torch.sqrt(D)
        return 0.8 * (torch.eye(adj.size(0), device=device) +
                      0.8 * D_inv_sqrt.view(-1, 1) * adj * D_inv_sqrt.view(1, -1))

    # Both adjacency tensors already live on `device`.
    return normalize_adj(semantic_adj), normalize_adj(spatial_adj)
# Smoke test: precompute and cache adjacency matrices for each dataset.
if __name__ == '__main__':
    base_config = {
        'sigma1': 0.1,
        'sigma2': 10,
        'thres1': 0.6,
        'thres2': 0.5,
        'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
    }
    # NOTE: the 307-node dataset (PEMS04) exists in `files` but is not
    # exercised by this loop.
    for node_count in (358, 883, 170):
        get_A_hat({'num_nodes': node_count, **base_config})