import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleExpert(nn.Module):
    """
    Inside each expert:
    1. Upsample to a fixed graph_size
    2. Apply Linear(input_dim -> hidden_dim) to every time step and node
    3. Downsample back to the original number of selected nodes
    """

    def __init__(self, input_dim, hidden_dim, graph_size):
        super().__init__()
        self.graph_size = graph_size
        self.linear = nn.Linear(input_dim, hidden_dim)

    def up_sample(self, x, target_size):
        # x: (B, T, N_sel, D)
        B, T, N, D = x.shape
        # 1) Merge B and T -> (B*T, N_sel, D)
        x2 = x.reshape(B * T, N, D)
        # 2) Transpose to (B*T, D, N_sel) for 1D linear interpolation
        x2 = x2.permute(0, 2, 1)  # (B*T, D, N_sel)
        # 3) Interpolate along the node axis up to graph_size
        x2 = F.interpolate(
            x2, size=target_size, mode="linear", align_corners=True
        )  # (B*T, D, graph_size)
        # 4) Restore to (B*T, graph_size, D)
        x2 = x2.permute(0, 2, 1)  # (B*T, graph_size, D)
        # 5) Split back into (B, T, graph_size, D)
        x_up = x2.reshape(B, T, target_size, D)
        return x_up

    def down_sample(self, x, target_size):
        # x: (B, T, graph_size, H)
        B, T, G, H = x.shape
        # 1) Merge B and T
        x2 = x.reshape(B * T, G, H)  # (B*T, graph_size, H)
        # 2) Transpose to (B*T, H, graph_size)
        x2 = x2.permute(0, 2, 1)  # (B*T, H, graph_size)
        # 3) Interpolate back down to target_size
        x2 = F.interpolate(
            x2, size=target_size, mode="linear", align_corners=True
        )  # (B*T, H, target_size)
        # 4) Restore (B*T, target_size, H)
        x2 = x2.permute(0, 2, 1)  # (B*T, target_size, H)
        # 5) Split back into (B, T, target_size, H)
        x_down = x2.reshape(B, T, target_size, H)
        return x_down

    def forward(self, x):
        # x: (B, T, N_sel, D)
        x_up = self.up_sample(x, self.graph_size)     # (B, T, graph_size, D)
        out = self.linear(x_up)                       # (B, T, graph_size, hidden_dim)
        out_down = self.down_sample(out, x.shape[2])  # (B, T, N_sel, hidden_dim)
        return out_down


class DGCRM_MOE(nn.Module):
    """
    Drops DGCRM and uses SimpleExpert as the expert; outputs (B, T, N, output_dim):
    - gate: last time step -> top_k experts
    - each expert: upsample -> linear -> downsample
    - sum all expert outputs -> (B, T, N, hidden_dim)
    - Linear(hidden_dim -> output_dim) -> (B, T, N, output_dim)
    - returns balance_loss for regularization
    """

    def __init__(self, args):
        super().__init__()
        self.graph_size = args["graph_size"]
        self.expert_nums = args["expert_nums"]
        self.top_k = args["top_k"]
        self.input_dim = args["input_dim"]
        self.hidden_dim = args["hidden_dim"]
        self.output_dim = args["output_dim"]
        self.num_node = args["num_nodes"]

        # gating network
        self.gate_proj = nn.Linear(self.input_dim, self.hidden_dim)
        self.gate = nn.Linear(self.hidden_dim, self.expert_nums)

        # list of SimpleExpert modules
        self.experts = nn.ModuleList(
            [
                SimpleExpert(self.input_dim, self.hidden_dim, self.graph_size)
                for _ in range(self.expert_nums)
            ]
        )

        # final multi-step prediction head: hidden_dim -> output_dim
        self.pred = nn.Linear(self.hidden_dim, self.output_dim)

    def forward(self, x, **kwargs):
        """
        x: (B, T, N, D_total); only channel 0 is used as the main observation
        returns:
            out: (B, T, N, output_dim)
            balance_loss: scalar
        """
        x = x[..., 0:1]  # (B, T, N, 1)
        B, T, N, D = x.shape

        # 1. Routing
        last = x[:, -1, :, :]             # (B, N, 1)
        g = F.relu(self.gate_proj(last))  # (B, N, hidden_dim)
        logits = self.gate(g)             # (B, N, expert_nums)
        rw = F.softmax(logits, dim=-1)    # (B, N, expert_nums)
        topk_w, topk_idx = torch.topk(
            rw, self.top_k, -1
        )  # (B, N, top_k); topk_w could optionally weight the expert outputs

        # 2. Expert processing
        expert_out = torch.zeros(
            B, T, N, self.hidden_dim, device=x.device, dtype=x.dtype
        )
        balance_loss = 0.0
        for i, expert in enumerate(self.experts):
            mask = topk_idx == i  # (B, N, top_k)
            if not mask.any():
                continue
            # penalize deviation of this expert's mean routing probability
            # from the uniform share 1/expert_nums
            balance_loss += (rw[..., i].mean() - 1.0 / self.expert_nums) ** 2
            for b in range(B):
                sel = torch.nonzero(mask[b].any(-1)).squeeze(-1)
                if sel.numel() == 0:
                    continue
                seq = x[b : b + 1, :, sel, :]  # (1, T, sel, 1)
                out_seq = expert(seq)          # (1, T, sel, hidden_dim)
                expert_out[b : b + 1, :, sel, :] += out_seq

        # 3. Prediction head
        out = self.pred(expert_out)  # (B, T, N, output_dim)
        return out, balance_loss
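

# Minimal usage sketch (not part of the original module). The args values and
# tensor shapes below are illustrative assumptions, not taken from any dataset;
# any dict with the keys read in DGCRM_MOE.__init__ works.
if __name__ == "__main__":
    args = {
        "graph_size": 64,   # fixed node count each expert interpolates to
        "expert_nums": 4,
        "top_k": 2,
        "input_dim": 1,     # must be 1, since forward keeps only channel 0
        "hidden_dim": 32,
        "output_dim": 12,
        "num_nodes": 207,
    }
    model = DGCRM_MOE(args)
    # (B, T, N, D_total); extra channels beyond the first are ignored
    x = torch.randn(8, 12, args["num_nodes"], 3)
    out, balance_loss = model(x)
    print(out.shape)     # torch.Size([8, 12, 207, 12])
    print(balance_loss)  # scalar routing-balance regularizer, added to the task loss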