import torch
import torch.nn as nn


class STLocalizedConv(nn.Module):
    def __init__(self, hidden_dim, dy_graph=None, **model_args):
        super().__init__()
        # gated temporal conv
        self.k_s = model_args['k_s']
        self.k_t = model_args['k_t']
        self.hidden_dim = hidden_dim

        # graph conv - only the dynamic graph is kept
        self.use_dynamic_hidden_graph = dy_graph
        # number of dynamic support matrices
        self.support_len = int(dy_graph) if dy_graph is not None else 0
        # num_matric = 1 (X_0) + number of dynamic graphs
        self.num_matric = 1 + self.support_len

        self.dropout = nn.Dropout(model_args['dropout'])
        self.fc_list_updt = nn.Linear(
            self.k_t * hidden_dim, self.k_t * hidden_dim, bias=False)
        self.gcn_updt = nn.Linear(
            self.hidden_dim * self.num_matric, self.hidden_dim)

        # others
        self.bn = nn.BatchNorm2d(self.hidden_dim)
        self.activation = nn.ReLU()

    def gconv(self, support, X_k, X_0):
        out = [X_0]
        batch_size, seq_len, _, _ = X_0.shape
        for graph in support:
            # align the graph's shape with X_k
            if graph.dim() == 3:
                # dynamic graph, shape [B, N, K*N]: replicate along the seq_len dimension
                graph = graph.unsqueeze(1).repeat(1, seq_len, 1, 1)            # [B, L, N, K*N]
            elif graph.dim() == 2:
                # static graph, shape [N, K*N]: broadcast over batch and seq_len
                graph = graph.unsqueeze(0).unsqueeze(1).repeat(
                    batch_size, seq_len, 1, 1)                                 # [B, L, N, K*N]
            if X_k.dim() == 4:
                # [B, L, N, K*N] x [B, L, K*N, D] -> [B, L, N, D]
                H_k = torch.matmul(graph, X_k)
            else:
                H_k = torch.matmul(graph, X_k.unsqueeze(1))
                H_k = H_k.squeeze(1)
            out.append(H_k)
        # concatenate X_0 and all propagated features along the feature dimension
        out = torch.cat(out, dim=-1)
        # if the number of supports differs from what was configured,
        # rebuild the output projection to match the actual input width
        if out.shape[-1] != self.gcn_updt.in_features:
            new_gcn_updt = nn.Linear(out.shape[-1], self.hidden_dim).to(out.device)
            # reuse the existing parameters where the dimensions overlap
            with torch.no_grad():
                min_dim = min(out.shape[-1], self.gcn_updt.in_features)
                new_gcn_updt.weight[:, :min_dim] = self.gcn_updt.weight[:, :min_dim]
                if new_gcn_updt.bias is not None and self.gcn_updt.bias is not None:
                    new_gcn_updt.bias.copy_(self.gcn_updt.bias)
            self.gcn_updt = new_gcn_updt
        out = self.gcn_updt(out)
        out = self.dropout(out)
        return out

    def get_graph(self, support):
        # Only used for static graphs (the static hidden graph and the predefined
        # graph); not used for the dynamic graph.
        if support is None or len(support) == 0:
            return []
        graph_ordered = []
        # mask out self-loops
        mask = 1 - torch.eye(support[0].shape[0]).to(support[0].device)
        for graph in support:
            k_1_order = graph  # 1st-order graph
            graph_ordered.append(k_1_order * mask)
            # e.g., k_s = 3 -> k in [2, 3]; k_s = 2 -> k in [2]
            for k in range(2, self.k_s + 1):
                k_1_order = torch.matmul(graph, k_1_order)
                graph_ordered.append(k_1_order * mask)
        # build the ST-localized graphs: repeat each adjacency k_t times along
        # the temporal axis and flatten to [num_nodes, k_t * num_nodes]
        st_local_graph = []
        for graph in graph_ordered:
            graph = graph.unsqueeze(-2).expand(-1, self.k_t, -1)
            graph = graph.reshape(
                graph.shape[0], graph.shape[1] * graph.shape[2])
            # [num_nodes, kernel_size x num_nodes]
            st_local_graph.append(graph)
        # [order, num_nodes, kernel_size x num_nodes]
        return st_local_graph

    def forward(self, X, dynamic_graph, static_graph=None):
        # X: [batch_size, seq_len, num_nodes, hidden_dim]
        # unfold along time -> [batch_size, seq_len', num_nodes, k_t, hidden_dim]
        X = X.unfold(1, self.k_t, 1).permute(0, 1, 2, 4, 3)
        # seq_len shrinks to seq_len - k_t + 1
        batch_size, seq_len, num_nodes, kernel_size, num_feat = X.shape

        # support - only the dynamic graph is kept
        support = []
        if self.use_dynamic_hidden_graph and dynamic_graph:
            # the k-order expansion is calculated in the dynamic_graph_constructor component
            support = support + dynamic_graph

        # parallelize the temporal convolution as a single linear layer
        X = X.reshape(batch_size, seq_len, num_nodes, kernel_size * num_feat)
        # [batch_size, seq_len, num_nodes, kernel_size * hidden_dim]
        out = self.fc_list_updt(X)
        out = self.activation(out)
        out = out.view(batch_size, seq_len, num_nodes, kernel_size, num_feat)
        # X_0: average over the temporal kernel -> [batch_size, seq_len, num_nodes, hidden_dim]
        X_0 = torch.mean(out, dim=-2)
        # X_k: [batch_size, seq_len, kernel_size * num_nodes, hidden_dim]
        X_k = out.transpose(-3, -2).reshape(
            batch_size, seq_len, kernel_size * num_nodes, num_feat)

        # with no supports, fall back to the temporally aggregated features
        if len(support) == 0:
            return X_0

        # [N, k_t*N] x [k_t*N, D] -> [N, D]: [batch_size, seq_len, num_nodes, hidden_dim]
        hidden = self.gconv(support, X_k, X_0)
        return hidden
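

# --- Minimal usage sketch (illustrative only, not part of the original module) ---
# The shapes and the value of `dy_graph` below are assumptions inferred from
# forward(): X is [batch, seq_len, num_nodes, hidden_dim] and `dynamic_graph`
# is assumed to be a list of ST-localized adjacencies of shape
# [batch, num_nodes, k_t * num_nodes], one per dynamic support matrix.
if __name__ == "__main__":
    batch, seq_len, num_nodes, hidden_dim = 2, 12, 8, 16
    k_s, k_t, n_dy_graphs = 2, 3, 2

    layer = STLocalizedConv(hidden_dim, dy_graph=n_dy_graphs,
                            k_s=k_s, k_t=k_t, dropout=0.1)

    X = torch.randn(batch, seq_len, num_nodes, hidden_dim)
    # One already ST-localized dynamic graph per support matrix: [B, N, k_t * N].
    dynamic_graph = [
        torch.softmax(torch.randn(batch, num_nodes, k_t * num_nodes), dim=-1)
        for _ in range(n_dy_graphs)
    ]

    hidden = layer(X, dynamic_graph)
    # unfold() shortens the temporal axis to seq_len - k_t + 1
    print(hidden.shape)  # expected: torch.Size([2, 10, 8, 16])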