TrafficWheel/model/EXP/trash/EXP1.py

123 lines
4.9 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleExpert(nn.Module):
    """Lightweight expert that operates on a variable number of selected nodes.

    The expert does only three things:
      1. linearly interpolates the node axis up to a fixed ``graph_size``;
      2. applies ``Linear(input_dim -> hidden_dim)`` at every time step and node;
      3. linearly interpolates the node axis back down to the original number
         of selected nodes.
    """

    def __init__(self, input_dim, hidden_dim, graph_size):
        super().__init__()
        self.graph_size = graph_size
        self.linear = nn.Linear(input_dim, hidden_dim)

    @staticmethod
    def _resize_nodes(x, target_size):
        """Linearly interpolate axis 2 of a 4-D (B, T, N, C) tensor to ``target_size``."""
        batch, steps, nodes, channels = x.shape
        # F.interpolate expects (batch, channels, length): fold B and T into one
        # batch axis and move the channel axis in front of the node axis.
        flat = x.reshape(batch * steps, nodes, channels).transpose(1, 2)
        resized = F.interpolate(flat, size=target_size, mode='linear', align_corners=True)
        # Undo the transpose and unfold B and T again.
        return resized.transpose(1, 2).reshape(batch, steps, target_size, channels)

    def up_sample(self, x, target_size):
        """Interpolate (B, T, N_sel, D) along the node axis to (B, T, target_size, D)."""
        return self._resize_nodes(x, target_size)

    def down_sample(self, x, target_size):
        """Interpolate (B, T, graph_size, H) along the node axis to (B, T, target_size, H)."""
        return self._resize_nodes(x, target_size)

    def forward(self, x):
        """Map (B, T, N_sel, input_dim) to (B, T, N_sel, hidden_dim)."""
        n_selected = x.shape[2]
        widened = self.up_sample(x, self.graph_size)  # (B, T, graph_size, input_dim)
        projected = self.linear(widened)  # (B, T, graph_size, hidden_dim)
        return self.down_sample(projected, n_selected)  # (B, T, N_sel, hidden_dim)
class DGCRM_MOE(nn.Module):
    """Mixture of SimpleExperts (DGCRM removed); outputs (B, T, N, output_dim).

    Pipeline:
      - gate: the last time step is scored per node, and each node is routed
        to its ``top_k`` highest-probability experts;
      - each expert: up-sample -> linear -> down-sample over its selected nodes;
      - the outputs of every expert a node was routed to are summed (unweighted)
        into a (B, T, N, hidden_dim) tensor;
      - Linear(hidden_dim -> output_dim) -> (B, T, N, output_dim);
      - a load-balancing loss is returned for regularisation.

    Args:
        args: mapping read with keys ``graph_size``, ``expert_nums``, ``top_k``,
            ``input_dim``, ``hidden_dim``, ``output_dim`` and ``num_nodes``.
            ``forward`` keeps only channel 0 of its input, so ``input_dim``
            must be 1 for the gate and the experts to accept it.
    """

    def __init__(self, args):
        super().__init__()
        self.graph_size = args['graph_size']
        self.expert_nums = args['expert_nums']
        self.top_k = args['top_k']
        self.input_dim = args['input_dim']
        self.hidden_dim = args['hidden_dim']
        self.output_dim = args['output_dim']
        self.num_node = args['num_nodes']  # stored for callers; not used by forward
        # Gate network: per-node scores over the experts.
        self.gate_proj = nn.Linear(self.input_dim, self.hidden_dim)
        self.gate = nn.Linear(self.hidden_dim, self.expert_nums)
        # Expert pool.
        self.experts = nn.ModuleList([
            SimpleExpert(self.input_dim, self.hidden_dim, self.graph_size)
            for _ in range(self.expert_nums)
        ])
        # Final prediction head: hidden_dim -> output_dim.
        self.pred = nn.Linear(self.hidden_dim, self.output_dim)

    def forward(self, x, **kwargs):
        """Route nodes to experts, run them, and project to ``output_dim``.

        Args:
            x: tensor of shape (B, T, N, D_total); only channel 0 is used as
                the primary observation.
            **kwargs: accepted for caller compatibility and ignored.

        Returns:
            out: (B, T, N, output_dim) tensor.
            balance_loss: scalar tensor, the sum over *all* experts of the
                squared gap between the expert's mean routing probability and
                the uniform share ``1 / expert_nums``.
        """
        x = x[..., 0:1]  # (B, T, N, 1)
        B, T, N, _ = x.shape

        # 1. Routing from the last time step.
        last = x[:, -1, :, :]  # (B, N, 1)
        g = F.relu(self.gate_proj(last))  # (B, N, hidden_dim)
        logits = self.gate(g)  # (B, N, expert_nums)
        rw = F.softmax(logits, dim=-1)  # (B, N, expert_nums)
        # The top-k probabilities are not used to weight expert outputs, and
        # topk indices are non-differentiable, so the gate is trained only
        # through balance_loss.
        _, topk_idx = torch.topk(rw, self.top_k, dim=-1)  # (B, N, top_k)

        # Every expert contributes to the balance loss, including experts no
        # node picked: those sit furthest below the uniform share.
        mean_prob = rw.mean(dim=(0, 1))  # (expert_nums,)
        balance_loss = ((mean_prob - 1.0 / self.expert_nums) ** 2).sum()

        # 2. Expert processing. new_zeros keeps x's dtype/device, so a
        # .double()/.half() model does not hit a dtype mismatch in self.pred.
        expert_out = x.new_zeros(B, T, N, self.hidden_dim)
        for i, expert in enumerate(self.experts):
            routed = (topk_idx == i).any(dim=-1)  # (B, N): node picked expert i
            if not routed.any():
                continue
            # Selected-node counts differ per sample, so experts run per sample.
            for b in range(B):
                sel = routed[b].nonzero(as_tuple=True)[0]
                if sel.numel() == 0:
                    continue
                seq = x[b:b + 1, :, sel, :]  # (1, T, n_sel, 1)
                expert_out[b:b + 1, :, sel, :] += expert(seq)  # (1, T, n_sel, hidden_dim)

        # 3. Prediction head.
        out = self.pred(expert_out)  # (B, T, N, output_dim)
        return out, balance_loss