import torch
import torch.nn as nn


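# Spatio-temporal localized convolution: a gated temporal convolution over
# sliding windows of length k_t combined with a graph convolution over a set
# of (dynamic) support matrices. Shape shorthand used in the comments below:
# B = batch, L = sequence length, N = nodes, K = k_t, D = hidden_dim.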
class STLocalizedConv(nn.Module):
    def __init__(self, hidden_dim, dy_graph=None, **model_args):
        super().__init__()
        # gated temporal conv
        self.k_s = model_args['k_s']
        self.k_t = model_args['k_t']
        self.hidden_dim = hidden_dim

        # graph conv - keep only the dynamic graph
        self.use_dynamic_hidden_graph = dy_graph

        # only the dynamic graph is considered
        self.support_len = int(dy_graph) if dy_graph is not None else 0

        # num_matric = 1 (for X_0) + number of dynamic graphs
        self.num_matric = 1 + self.support_len
        self.dropout = nn.Dropout(model_args['dropout'])

        self.fc_list_updt = nn.Linear(
            self.k_t * hidden_dim, self.k_t * hidden_dim, bias=False)
        self.gcn_updt = nn.Linear(
            self.hidden_dim * self.num_matric, self.hidden_dim)

        # others
        self.bn = nn.BatchNorm2d(self.hidden_dim)
        self.activation = nn.ReLU()

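    # Shape contract for gconv (as used by forward below):
    #   X_0:   [B, L, N, D]    window-averaged features, the identity term
    #   X_k:   [B, L, K*N, D]  per-window features flattened over the kernel
    #   graph: [B, N, K*N] (dynamic) or [N, K*N] (static), one per support
    # Each propagation computes graph @ X_k -> [B, L, N, D]; all terms are
    # then concatenated feature-wise and projected back to D by gcn_updt.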
    def gconv(self, support, X_k, X_0):
        out = [X_0]
        batch_size, seq_len, _, hidden_dim = X_0.shape

        for graph in support:
            # make sure the graph's shape matches X_k
            if len(graph.shape) == 3:    # dynamic graph, shape [B, N, K*N]
                # replicate the graph along the seq_len dimension
                graph = graph.unsqueeze(1).repeat(1, seq_len, 1, 1)  # [B, L, N, K*N]
            elif len(graph.shape) == 2:  # static graph, shape [N, K*N]
                graph = graph.unsqueeze(0).unsqueeze(1).repeat(
                    batch_size, seq_len, 1, 1)                       # [B, L, N, K*N]

            # make sure X_k has the expected shape
            if X_k.dim() == 4:  # [B, L, K*N, D]
                # matrix product: [B, L, N, K*N] x [B, L, K*N, D] -> [B, L, N, D]
                H_k = torch.matmul(graph, X_k)
            else:
                H_k = torch.matmul(graph, X_k.unsqueeze(1))
                H_k = H_k.squeeze(1)

            out.append(H_k)

        # concatenate all results along the feature dimension
        out = torch.cat(out, dim=-1)

        # dynamically adjust the input dimension of the linear layer
        if out.shape[-1] != self.gcn_updt.in_features:
            # create a new linear layer matching the current input dimension
            new_gcn_updt = nn.Linear(out.shape[-1], self.hidden_dim).to(out.device)
            # copy over the existing parameters where possible
            with torch.no_grad():
                min_dim = min(out.shape[-1], self.gcn_updt.in_features)
                new_gcn_updt.weight[:, :min_dim] = self.gcn_updt.weight[:, :min_dim]
                if new_gcn_updt.bias is not None and self.gcn_updt.bias is not None:
                    new_gcn_updt.bias.copy_(self.gcn_updt.bias)
            self.gcn_updt = new_gcn_updt

        out = self.gcn_updt(out)
        out = self.dropout(out)
        return out

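    # Note on the resizing above: a layer re-created inside gconv after the
    # optimizer has been built is not registered with that optimizer, so its
    # parameters would not be updated during training. Sizing gcn_updt once
    # in __init__ via num_matric avoids the re-creation entirely.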
    def get_graph(self, support):
        # Only used for static graphs (the static hidden graph and the
        # predefined graph); not used for the dynamic graph.
        if support is None or len(support) == 0:
            return []

        graph_ordered = []
        mask = 1 - torch.eye(support[0].shape[0]).to(support[0].device)
        for graph in support:
            k_1_order = graph  # 1st order
            graph_ordered.append(k_1_order * mask)
            # e.g., order = 3, k = [2, 3]; order = 2, k = [2]
            for k in range(2, self.k_s + 1):
                k_1_order = torch.matmul(graph, k_1_order)
                graph_ordered.append(k_1_order * mask)
        # get the ST-localized graph
        st_local_graph = []
        for graph in graph_ordered:
            graph = graph.unsqueeze(-2).expand(-1, self.k_t, -1)
            graph = graph.reshape(
                graph.shape[0], graph.shape[1] * graph.shape[2])
            # [num_nodes, kernel_size x num_nodes]
            st_local_graph.append(graph)
        # [order, num_nodes, kernel_size x num_nodes]
        return st_local_graph

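    # Worked example for get_graph: with k_s = 2, k_t = 3 and one static
    # graph A of shape [N, N], graph_ordered holds A * mask and (A @ A) * mask.
    # Each order is then tiled k_t times and flattened to [N, 3 * N], so the
    # method returns a list of 2 tensors of shape [N, 3 * N].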
    def forward(self, X, dynamic_graph, static_graph=None):
        # static_graph is kept for interface compatibility; it is unused here.
        # X: [batch_size, seq_len, num_nodes, num_feat]
        # unfold the time axis into sliding windows of length k_t:
        # [batch_size, seq_len', num_nodes, k_t, num_feat]
        X = X.unfold(1, self.k_t, 1).permute(0, 1, 2, 4, 3)
        # seq_len shrinks to seq_len - k_t + 1 because of the unfold
        batch_size, seq_len, num_nodes, kernel_size, num_feat = X.shape

        # support - keep only the dynamic graph
        support = []
        if self.use_dynamic_hidden_graph and dynamic_graph:
            # the k-order expansion is calculated in the
            # dynamic_graph_constructor component
            support = support + dynamic_graph

        # parallelize
        X = X.reshape(batch_size, seq_len, num_nodes, kernel_size * num_feat)
        # [batch_size, seq_len, num_nodes, kernel_size * hidden_dim]
        out = self.fc_list_updt(X)
        out = self.activation(out)
        out = out.view(batch_size, seq_len, num_nodes, kernel_size, num_feat)
        X_0 = torch.mean(out, dim=-2)
        # [batch_size, seq_len, kernel_size x num_nodes, hidden_dim]
        X_k = out.transpose(-3, -2).reshape(
            batch_size, seq_len, kernel_size * num_nodes, num_feat)

        # if support is empty, return X_0 directly
        if len(support) == 0:
            return X_0

        # [N, K*N] x [K*N, D] -> [N, D] per step (e.g., 3N when k_t = 3):
        # output is [batch_size, seq_len, num_nodes, hidden_dim]
        hidden = self.gconv(support, X_k, X_0)
        return hidden
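

# A minimal smoke test, assuming hyperparameter values (k_s, k_t, dropout,
# hidden_dim, batch/sequence/node sizes) chosen only for illustration; the
# dynamic graph is a random stand-in for the output of the (external)
# dynamic_graph_constructor component.
if __name__ == "__main__":
    B, L, N, D, K_T = 2, 12, 20, 32, 3
    model_args = {'k_s': 2, 'k_t': K_T, 'dropout': 0.1}
    conv = STLocalizedConv(hidden_dim=D, dy_graph=True, **model_args)

    X = torch.randn(B, L, N, D)
    # one dynamic support matrix per sample: [B, N, k_t * N]
    dynamic_graph = [torch.softmax(torch.randn(B, N, K_T * N), dim=-1)]

    out = conv(X, dynamic_graph)
    # the unfold shortens the sequence to L - k_t + 1
    print(out.shape)  # expected: torch.Size([2, 10, 20, 32])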