TrafficWheel/model/D2STGNN/diffusion_block/dif_model.py


import torch
import torch.nn as nn


class STLocalizedConv(nn.Module):
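    """Spatial-temporal localized convolution.

    Combines a temporal convolution over a local window of k_t steps with a
    k_s-order graph convolution; this variant keeps only the dynamic graph
    support.
    """
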
    def __init__(self, hidden_dim, dy_graph=None, **model_args):
        super().__init__()
        # gated temporal conv
        self.k_s = model_args['k_s']
        self.k_t = model_args['k_t']
        self.hidden_dim = hidden_dim
        # graph conv - keep only the dynamic graph
        self.use_dynamic_hidden_graph = dy_graph
        # only dynamic graphs are considered
        self.support_len = int(dy_graph) if dy_graph is not None else 0
        # num_matric = 1 (X_0) + number of dynamic graphs
        self.num_matric = 1 + self.support_len
        self.dropout = nn.Dropout(model_args['dropout'])
        self.fc_list_updt = nn.Linear(
            self.k_t * hidden_dim, self.k_t * hidden_dim, bias=False)
        self.gcn_updt = nn.Linear(
            self.hidden_dim * self.num_matric, self.hidden_dim)
        # others
        self.bn = nn.BatchNorm2d(self.hidden_dim)
        self.activation = nn.ReLU()

    def gconv(self, support, X_k, X_0):
        out = [X_0]
        batch_size, seq_len, _, hidden_dim = X_0.shape
        for graph in support:
            # make sure the graph's shape matches X_k
            if len(graph.shape) == 3:  # dynamic graph, shape [B, N, K*N]
                # repeat the graph along the seq_len dimension
                graph = graph.unsqueeze(1).repeat(1, seq_len, 1, 1)  # [B, L, N, K*N]
            elif len(graph.shape) == 2:  # static graph, shape [N, K*N]
                graph = graph.unsqueeze(0).unsqueeze(1).repeat(
                    batch_size, seq_len, 1, 1)  # [B, L, N, K*N]
            # make sure X_k has the expected shape
            if X_k.dim() == 4:  # [B, L, K*N, D]
                # matmul: [B, L, N, K*N] x [B, L, K*N, D] -> [B, L, N, D]
                H_k = torch.matmul(graph, X_k)
            else:
                H_k = torch.matmul(graph, X_k.unsqueeze(1))
                H_k = H_k.squeeze(1)
            out.append(H_k)
        # concatenate all results
        out = torch.cat(out, dim=-1)
        # dynamically adjust the input dimension of the linear layer
        if out.shape[-1] != self.gcn_updt.in_features:
            # build a new linear layer matching the current input dimension;
            # note: parameters created here at forward time are not seen by an
            # optimizer that was constructed earlier
            new_gcn_updt = nn.Linear(out.shape[-1], self.hidden_dim).to(out.device)
            # copy the old parameters (where possible)
            with torch.no_grad():
                min_dim = min(out.shape[-1], self.gcn_updt.in_features)
                new_gcn_updt.weight[:, :min_dim] = self.gcn_updt.weight[:, :min_dim]
                if new_gcn_updt.bias is not None and self.gcn_updt.bias is not None:
                    new_gcn_updt.bias.copy_(self.gcn_updt.bias)
            self.gcn_updt = new_gcn_updt
        out = self.gcn_updt(out)
        out = self.dropout(out)
        return out
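
    # Shape walk-through for gconv (illustrative, not from the original
    # comments): with one dynamic graph of shape [B, N, K*N] tiled to
    # [B, L, N, K*N] and X_k of shape [B, L, K*N, D], matmul gives
    # H_k: [B, L, N, D]; concatenating [X_0, H_k] yields [B, L, N, 2*D],
    # which gcn_updt projects back to [B, L, N, D].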

    def get_graph(self, support):
        # Only used for static graphs (the static hidden graph and the
        # predefined graph); dynamic graphs are handled elsewhere.
        if support is None or len(support) == 0:
            return []
        graph_ordered = []
        mask = 1 - torch.eye(support[0].shape[0]).to(support[0].device)
        for graph in support:
            k_1_order = graph  # 1st order
            graph_ordered.append(k_1_order * mask)
            # e.g., order = 3, k = [2, 3]; order = 2, k = [2]
            for k in range(2, self.k_s + 1):
                k_1_order = torch.matmul(graph, k_1_order)
                graph_ordered.append(k_1_order * mask)
        # get the st-localized graph
        st_local_graph = []
        for graph in graph_ordered:
            graph = graph.unsqueeze(-2).expand(-1, self.k_t, -1)
            graph = graph.reshape(
                graph.shape[0], graph.shape[1] * graph.shape[2])
            # [num_nodes, kernel_size x num_nodes]
            st_local_graph.append(graph)
        # [order, num_nodes, kernel_size x num_nodes]
        return st_local_graph
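
    # Example (illustrative): with k_s = 2, k_t = 3 and a single [N, N]
    # static adjacency, get_graph returns k_s graphs (the self-loop-masked
    # 1-hop and 2-hop matrices), each tiled k_t times along the column axis
    # into shape [N, k_t * N] -- one block per step of the local temporal
    # window.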

    def forward(self, X, dynamic_graph, static_graph=None):
        # X: [bs, seq, nodes, feat]
        # [bs, seq, num_nodes, ks, num_feat]
        X = X.unfold(1, self.k_t, 1).permute(0, 1, 2, 4, 3)
        # seq_len is shortened by the unfold: seq_len' = seq_len - k_t + 1
        batch_size, seq_len, num_nodes, kernel_size, num_feat = X.shape
        # support - keep only the dynamic graphs
        support = []
        if self.use_dynamic_hidden_graph and dynamic_graph:
            # the k-order expansion is calculated in the
            # dynamic_graph_constructor component
            support = support + dynamic_graph
        # parallelize
        X = X.reshape(batch_size, seq_len, num_nodes, kernel_size * num_feat)
        # batch_size, seq_len, num_nodes, kernel_size * hidden_dim
        out = self.fc_list_updt(X)
        out = self.activation(out)
        out = out.view(batch_size, seq_len, num_nodes, kernel_size, num_feat)
        X_0 = torch.mean(out, dim=-2)
        # batch_size, seq_len, kernel_size x num_nodes, hidden_dim
        X_k = out.transpose(-3, -2).reshape(
            batch_size, seq_len, kernel_size * num_nodes, num_feat)
        # if support is empty, return X_0 directly
        if len(support) == 0:
            return X_0
        # Nx3N 3NxD -> NxD: batch_size, seq_len, num_nodes, hidden_dim
        hidden = self.gconv(support, X_k, X_0)
        return hidden
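

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original pipeline): the
    # hyper-parameters below are illustrative assumptions, and the random
    # dynamic graph stands in for the output of the dynamic_graph_constructor
    # component.
    torch.manual_seed(0)
    B, L, N, D = 2, 12, 7, 32  # batch, seq_len, num_nodes, hidden_dim
    model_args = {'k_s': 2, 'k_t': 3, 'dropout': 0.1}
    layer = STLocalizedConv(D, dy_graph=1, **model_args)

    X = torch.randn(B, L, N, D)
    # one dynamic graph of shape [B, N, k_t * N]
    dyn_graph = [torch.softmax(torch.randn(B, N, model_args['k_t'] * N), dim=-1)]

    out = layer(X, dyn_graph)
    # unfold shortens the sequence: L' = L - k_t + 1 = 10
    print(out.shape)  # expected: torch.Size([2, 10, 7, 32])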