Project-I/models/STDEN/STDEN_modules.md

4.5 KiB
Raw Permalink Blame History

STDEN 模块与执行流(缩进层级表)

模块 类/函数 输入 (shape) 输出 (shape)
1 STDENModel.forward inputs: (seq_len, batch_size, num_edges x input_dim) outputs: (horizon, batch_size, num_edges x output_dim); fe: (nfe:int, time:float)
1.1 Encoder_z0_RNN.forward (seq_len, batch_size, num_edges x input_dim) mean: (1, batch_size, num_nodes x latent_dim); std: (1, batch_size, num_nodes x latent_dim)
1.1.1 utils.sample_standard_gaussian mu: (n_traj, batch, num_nodes x latent_dim); sigma: 同形状 z0: (n_traj, batch, num_nodes x latent_dim)
1.2 DiffeqSolver.forward first_point: (n_traj, batch, num_nodes x latent_dim); t: (horizon,) sol_ys: (horizon, n_traj, batch, num_nodes x latent_dim); fe: (nfe:int, time:float)
1.2.1 ODEFunc.forward t_local: 标量/1D; y: (B, num_nodes x latent_dim) dy/dt: (B, num_nodes x latent_dim)
1.3 Decoder.forward (horizon, n_traj, batch, num_nodes x latent_dim) (horizon, batch, num_edges x output_dim)

细节模块 — Encoder_z0_RNN

步骤 操作 输入 (shape) 输出 (shape)
1 重塑到边批 (seq_len, batch, num_edges x input_dim) (seq_len, batch, num_edges, input_dim)
2 合并边到批 (seq_len, batch, num_edges, input_dim) (seq_len, batch x num_edges, input_dim)
3 GRU 序列编码 同上 (seq_len, batch x num_edges, rnn_units)
4 取最后时间步 同上 (batch x num_edges, rnn_units)
5 还原边维 (batch x num_edges, rnn_units) (batch, num_edges, rnn_units)
6 转置 + 边→节点映射 (batch, num_edges, rnn_units) 经 inv_grad (batch, num_nodes, rnn_units)
7 全连接映射到 2x latent (batch, num_nodes, rnn_units) (batch, num_nodes, 2 x latent_dim)
8 拆分均值/标准差 同上 mean/std: (batch, num_nodes, latent_dim)
9 展平并加时间维 (batch, num_nodes, latent_dim) (1, batch, num_nodes x latent_dim)

备注inv_grad 来源于 utils.graph_grad(adj).T 并做缩放;hiddens_to_z0 为两层 MLP + Tanh 后线性映射至 2 x latent_dim。


细节模块 — 采样utils.sample_standard_gaussian

步骤 操作 输入 (shape) 输出 (shape)
1 重复到 n_traj mean/std: (1, batch, N·Z) → 重复 (n_traj, batch, N·Z)
2 重参数化采样 mu, sigma z0: (n_traj, batch, N·Z)

其中 N·Z = num_nodes x latent_dim。


细节模块 — DiffeqSolver含 ODEFunc 调用)

步骤 操作 输入 (shape) 输出 (shape)
1 合并样本维度 first_point: (n_traj, batch, N·Z) (n_traj x batch, N·Z)
2 ODE 积分 t: (horizon,), y0 pred_y: (horizon, n_traj x batch, N·Z)
3 还原维度 同上 (horizon, n_traj, batch, N·Z)
4 统计代价 odefunc.nfe, elapsed_time fe: (nfe:int, time:float)

ODEFunc 默认filter_type="default")为扩散过程:随机游走支持 + 多阶图卷积门控。


细节模块 — ODEFunc默认扩散过程

步骤 操作 输入 (shape) 输出 (shape)
1 形状整理 y: (B, N·Z) → (B, N, Z) (B, N, Z)
2 多阶图卷积 _gconv (B, N, Z) (B, N, Z') 按需设置 Z'(通常保持 Z
3 门控 θ _gconv(..., output=latent_dim) → Sigmoid θ: (B, N·Z)
4 生成场 ode_func_net 堆叠 _gconv + 激活 f(y): (B, N·Z)
5 右端梯度 - θ ⊙ f(y) dy/dt: (B, N·Z)

说明:

  • 支撑矩阵来自 utils.calculate_random_walk_matrix(adj)(正向/反向)并构造稀疏 Chebyshev 递推的多阶通道。
  • filter_type="unkP",则使用 create_net 的全连接网络在节点域逐点计算梯度。

细节模块 — Decoder

步骤 操作 输入 (shape) 输出 (shape)
1 重塑到节点域 (T, S, B, N·Z) → (T, S, B, N, Z) (T, S, B, N, Z)
2 节点→边映射 乘以 graph_grad (N, E) (T, S, B, Z, E)
3 轨迹与通道均值 对 S 和 Z 维做均值 (T, B, E)
4 展平到输出维 考虑 output_dim通常为 1 (T, B, E x output_dim)

符号T=horizonS=n_traj_samplesN=num_nodesE=num_edgesZ=latent_dimB=batch。


备注与约定

  • 内部采用边展平后的时序输入:(seq_len, batch, num_edges x input_dim)
  • 图算子:utils.graph_grad(adj) 形状 (N, E)utils.calculate_random_walk_matrix(adj) 生成随机游走稀疏矩阵用于图卷积。
  • 关键超参数(由配置传入):latent_dim, rnn_units, gcn_step, n_traj_samples, ode_method, horizon, input_dim, output_dim