TrafficWheel/model/STSGCN/get_adj.py

97 lines
3.4 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import numpy as np
import pandas as pd
import torch
def get_adj(args):
dataset_path = './data'
match args['num_nodes']:
case 358:
dataset_name = 'PEMS03'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
A = get_adjacency_matrix(adj_path, args['num_nodes'], args['construct_type'], id_filename=id)
case 307:
dataset_name = 'PEMS04'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
A = get_adjacency_matrix(adj_path, args['num_nodes'], args['construct_type'])
case 883:
dataset_name = 'PEMS07'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
A = get_adjacency_matrix(adj_path, args['num_nodes'], args['construct_type'])
case 170:
dataset_name = 'PEMS08'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
A = get_adjacency_matrix(adj_path, args['num_nodes'], args['construct_type'])
local_adj = construct_adj(A, args['strides'])
local_adj = torch.FloatTensor(local_adj)
return local_adj
def get_adjacency_matrix(distance_df_filename, num_of_vertices, type_='connectivity', id_filename=None):
    """Build a symmetric float32 adjacency matrix from an edge-list csv.

    :param distance_df_filename: str, path to the edge csv. It is read with
        ``pd.read_csv``, so the first line is treated as a header. Each row is
        used as ``(from, to, distance)``. If the file does not have exactly 3
        columns, every row is skipped and the matrix stays all zeros.
    :param num_of_vertices: int, number of nodes; the result is
        (num_of_vertices, num_of_vertices).
    :param type_: str, ``'connectivity'`` (each edge gets weight 1) or
        ``'distance'`` (each edge gets weight 1 / distance).
    :param id_filename: str or None, path to a text file with one integer node
        id per line. When given, csv ids are mapped to their 0-based line
        index in this file before being used as matrix indices.
    :return: np.ndarray of dtype float32.
    :raises ValueError: if ``type_`` is neither ``'connectivity'`` nor ``'distance'``.
    """
    # Validate once up front. The old code compared the builtin ``type``
    # (not ``type_``), so 'distance' could never be selected. The id-remapped
    # branch also ignored ``type_`` entirely.
    if type_ not in ('connectivity', 'distance'):
        raise ValueError("type_ error, must be "
                         "connectivity or distance!")
    n = int(num_of_vertices)
    A = np.zeros((n, n), dtype=np.float32)

    id_dict = None
    if id_filename:
        with open(id_filename, 'r') as f:
            # raw node id -> row/column index (its line position in the file)
            id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}

    df = pd.read_csv(distance_df_filename)
    for row in df.values:
        if len(row) != 3:
            continue
        i, j = int(row[0]), int(row[1])
        if id_dict is not None:
            i, j = id_dict[i], id_dict[j]
        # Only parse the third column when the distance is actually needed.
        weight = 1.0 if type_ == 'connectivity' else 1 / float(row[2])
        A[i, j] = weight
        A[j, i] = weight
    return A
def construct_adj(A, steps):
    """Build the localized spatial-temporal graph over ``steps`` time steps.

    The result is a block matrix:
    - each diagonal N x N block is the spatial graph ``A``;
    - node ``i`` at step ``k`` is linked (both directions) to itself at step ``k + 1``;
    - every node has a self-loop.

    :param A: adjacency matrix of shape (N, N) (anything assignable into an
        N x N slice of a numpy array).
    :param steps: number of time steps stacked into the graph.
    :return: dense np.ndarray (float64) of shape (N * steps, N * steps).
        Note: the original docstring claimed a csr_matrix, but a dense array
        has always been returned.
    """
    N = len(A)
    adj = np.zeros((N * steps, N * steps))
    # Diagonal blocks: each time step's own spatial graph.
    for t in range(steps):
        adj[t * N:(t + 1) * N, t * N:(t + 1) * N] = A
    # Temporal edges connect flat index r (step k, node i) to r + N (step k + 1,
    # node i). These are the +N / -N off-diagonals, set in one vectorized
    # assignment each.
    src = np.arange(N * (steps - 1))
    adj[src, src + N] = 1
    adj[src + N, src] = 1
    # Self-loops on every node at every step.
    np.fill_diagonal(adj, 1)
    return adj