136 lines
4.8 KiB
Python
Executable File
136 lines
4.8 KiB
Python
Executable File
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
|
||
class SimpleExpert(nn.Module):
    """Per-node linear expert applied on a fixed-size, resampled node axis.

    The expert performs three steps:

    1. Linearly resample the node axis (axis 2) of the input to
       ``graph_size`` nodes.
    2. Apply ``Linear(input_dim -> hidden_dim)`` to every
       (time step, node) feature vector.
    3. Linearly resample the node axis back to the number of nodes the
       input had.

    Args:
        input_dim: size of the last (feature) axis of the input.
        hidden_dim: size of the last axis of the output.
        graph_size: node count the linear projection is applied at.
    """

    def __init__(self, input_dim, hidden_dim, graph_size):
        super().__init__()
        self.graph_size = graph_size
        self.linear = nn.Linear(input_dim, hidden_dim)

    @staticmethod
    def _resize_nodes(x, target_size):
        """Linearly interpolate axis 2 of a rank-4 tensor to ``target_size``.

        ``x`` is (B, T, N, C); the result is (B, T, target_size, C).
        ``align_corners=True`` aligns the first and last nodes of input
        and output (when ``target_size > 1``).
        """
        batch, steps, nodes, channels = x.shape
        # 1-D linear interpolation resizes the last axis of a
        # (batch, channels, length) tensor: fold B and T together and move
        # the node axis to the end.
        flat = x.reshape(batch * steps, nodes, channels).transpose(1, 2)
        resized = F.interpolate(
            flat, size=target_size, mode="linear", align_corners=True
        )
        # Undo the axis swap and unfold B and T again.
        return resized.transpose(1, 2).reshape(batch, steps, target_size, channels)

    def up_sample(self, x, target_size):
        """Resample (B, T, N_sel, D) to (B, T, target_size, D)."""
        return self._resize_nodes(x, target_size)

    def down_sample(self, x, target_size):
        """Resample (B, T, graph_size, H) to (B, T, target_size, H)."""
        return self._resize_nodes(x, target_size)

    def forward(self, x):
        """Map (B, T, N_sel, input_dim) to (B, T, N_sel, hidden_dim)."""
        num_selected = x.shape[2]
        projected = self.linear(self.up_sample(x, self.graph_size))
        return self.down_sample(projected, num_selected)
|
||
|
||
|
||
class DGCRM_MOE(nn.Module):
    """Top-k routed mixture of ``SimpleExpert`` modules.

    DGCRM is replaced by ``SimpleExpert`` experts; the output is
    (B, T, N, output_dim):

    - gate: each node's last time step is scored against every expert and
      the node is routed to its ``top_k`` highest-probability experts;
    - each expert: resample nodes up -> linear -> resample back down;
    - outputs of all experts a node is routed to are summed into a
      (B, T, N, hidden_dim) tensor;
    - ``Linear(hidden_dim -> output_dim)`` gives (B, T, N, output_dim);
    - a load-balancing loss is returned for regularisation.

    Args:
        args: mapping with keys ``graph_size``, ``expert_nums``, ``top_k``,
            ``input_dim``, ``hidden_dim``, ``output_dim`` and ``num_nodes``.
    """

    def __init__(self, args):
        super().__init__()
        self.graph_size = args["graph_size"]
        self.expert_nums = args["expert_nums"]
        self.top_k = args["top_k"]
        self.input_dim = args["input_dim"]
        self.hidden_dim = args["hidden_dim"]
        self.output_dim = args["output_dim"]
        # Stored from the config; not used inside this module.
        self.num_node = args["num_nodes"]

        # Gate network: input_dim -> hidden_dim -> one logit per expert.
        self.gate_proj = nn.Linear(self.input_dim, self.hidden_dim)
        self.gate = nn.Linear(self.hidden_dim, self.expert_nums)

        # One SimpleExpert per routing slot.
        self.experts = nn.ModuleList(
            [
                SimpleExpert(self.input_dim, self.hidden_dim, self.graph_size)
                for _ in range(self.expert_nums)
            ]
        )

        # Final prediction head: hidden_dim -> output_dim.
        self.pred = nn.Linear(self.hidden_dim, self.output_dim)

    def forward(self, x, **kwargs):
        """Route every node to its top-k experts and predict.

        Args:
            x: (B, T, N, D_total) tensor. Only the first ``input_dim``
                channels are used as the observation (channel 0 when
                ``input_dim == 1``).
            **kwargs: accepted for interface compatibility; ignored.

        Returns:
            out: (B, T, N, output_dim) prediction.
            balance_loss: scalar tensor, the sum over *all* experts of
                (mean routing probability - 1 / expert_nums) ** 2, where the
                mean is taken over batch and nodes.
        """
        # gate_proj and the experts expect exactly input_dim features.
        x = x[..., : self.input_dim]  # (B, T, N, input_dim)
        B, T, N, _ = x.shape

        # 1. Routing from the last time step.
        last = x[:, -1, :, :]  # (B, N, input_dim)
        logits = self.gate(F.relu(self.gate_proj(last)))  # (B, N, expert_nums)
        rw = F.softmax(logits, dim=-1)  # (B, N, expert_nums)
        # Top-k indices are not differentiable and the top-k weights are not
        # used to scale expert outputs, so the gate is trained only through
        # balance_loss.
        _, topk_idx = torch.topk(rw, self.top_k, dim=-1)  # (B, N, top_k)

        # Balance loss over every expert. Experts no node selected are the
        # most under-used ones and must not be skipped.
        mean_prob = rw.mean(dim=(0, 1))  # (expert_nums,)
        balance_loss = ((mean_prob - 1.0 / self.expert_nums) ** 2).sum()

        # 2. Expert dispatch. Per sample, an expert processes the nodes that
        # listed it among their top-k choices together (SimpleExpert
        # resamples them jointly along the node axis).
        # Match the input's dtype/device so the head works in half precision.
        expert_out = x.new_zeros(B, T, N, self.hidden_dim)
        for i, expert in enumerate(self.experts):
            routed = (topk_idx == i).any(dim=-1)  # (B, N)
            if not routed.any():
                continue
            for b in range(B):
                sel = torch.nonzero(routed[b]).squeeze(-1)
                if sel.numel() == 0:
                    continue
                seq = x[b : b + 1, :, sel, :]  # (1, T, n_sel, input_dim)
                # `sel` holds unique indices, so the indexed += is safe.
                expert_out[b : b + 1, :, sel, :] += expert(seq)

        # 3. Prediction head.
        out = self.pred(expert_out)  # (B, T, N, output_dim)
        return out, balance_loss
|