### STDEN modules and execution flow (hierarchical table)
Module | Class/function | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | STDENModel.forward | inputs: (seq_len, batch_size, num_edges x input_dim) | outputs: (horizon, batch_size, num_edges x output_dim); fe: (nfe:int, time:float)
1.1 | Encoder_z0_RNN.forward | (seq_len, batch_size, num_edges x input_dim) | mean: (1, batch_size, num_nodes x latent_dim); std: (1, batch_size, num_nodes x latent_dim)
1.1.1 | utils.sample_standard_gaussian | mu: (n_traj, batch, num_nodes x latent_dim), i.e. the mean repeated n_traj times; sigma: same shape | z0: (n_traj, batch, num_nodes x latent_dim)
1.2 | DiffeqSolver.forward | first_point: (n_traj, batch, num_nodes x latent_dim); t: (horizon,) | sol_ys: (horizon, n_traj, batch, num_nodes x latent_dim); fe: (nfe:int, time:float)
1.2.1 | ODEFunc.forward | t_local: scalar or 1-D; y: (B, num_nodes x latent_dim) | dy/dt: (B, num_nodes x latent_dim)
1.3 | Decoder.forward | (horizon, n_traj, batch, num_nodes x latent_dim) | (horizon, batch, num_edges x output_dim)
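
The shape flow in the table above can be checked with a small bookkeeping helper. This is only a sketch: `trace_shapes` is a hypothetical function that does not exist in the project, and the sizes passed to it are made up for illustration.

```python
def trace_shapes(seq_len, batch, num_nodes, num_edges,
                 input_dim, output_dim, latent_dim, n_traj, horizon):
    """Return (stage, shape) pairs following the module table."""
    nz = num_nodes * latent_dim  # flattened node-latent width, N·Z
    return [
        ("inputs", (seq_len, batch, num_edges * input_dim)),
        ("Encoder_z0_RNN mean/std", (1, batch, nz)),
        ("sampled z0", (n_traj, batch, nz)),
        ("DiffeqSolver sol_ys", (horizon, n_traj, batch, nz)),
        ("Decoder outputs", (horizon, batch, num_edges * output_dim)),
    ]


# Hypothetical sizes, chosen only for illustration.
shapes = trace_shapes(seq_len=12, batch=4, num_nodes=5, num_edges=8,
                      input_dim=2, output_dim=1, latent_dim=3,
                      n_traj=2, horizon=6)
for stage, shape in shapes:
    print(f"{stage:26s} {shape}")
```

Note that the input and output live on edges (`num_edges`), while every latent tensor in between lives on nodes (`num_nodes`).
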
---
### Module details: Encoder_z0_RNN
Step | Operation | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | Reshape to expose the edge dimension | (seq_len, batch, num_edges x input_dim) | (seq_len, batch, num_edges, input_dim)
2 | Merge edges into the batch dimension | (seq_len, batch, num_edges, input_dim) | (seq_len, batch x num_edges, input_dim)
3 | GRU sequence encoding | (seq_len, batch x num_edges, input_dim) | (seq_len, batch x num_edges, rnn_units)
4 | Take the last time step | (seq_len, batch x num_edges, rnn_units) | (batch x num_edges, rnn_units)
5 | Restore the edge dimension | (batch x num_edges, rnn_units) | (batch, num_edges, rnn_units)
6 | Transpose + edge→node mapping | (batch, num_edges, rnn_units), via inv_grad | (batch, num_nodes, rnn_units)
7 | Fully connected mapping to 2 x latent_dim | (batch, num_nodes, rnn_units) | (batch, num_nodes, 2 x latent_dim)
8 | Split into mean and standard deviation | (batch, num_nodes, 2 x latent_dim) | mean/std: (batch, num_nodes, latent_dim)
9 | Flatten and add a leading time dimension | (batch, num_nodes, latent_dim) | (1, batch, num_nodes x latent_dim)

Note: `inv_grad` is derived from `utils.graph_grad(adj).T` and then rescaled; `hiddens_to_z0` is a two-layer MLP with Tanh, followed by a linear mapping to 2 x latent_dim.

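
Step 6 (the edge→node mapping) can be sketched in plain Python for a single batch element. Two things here are assumptions rather than facts about the project: `graph_grad` is taken to be a signed node-by-edge incidence matrix, and the rescaling of `inv_grad` is taken to put weight 0.5 on each endpoint of an edge. The real `utils.graph_grad` and scaling may differ.

```python
def graph_grad(adj):
    """Signed incidence matrix (N, E): +1 at the source node, -1 at the target node.

    Assumed definition for illustration; not necessarily the one in utils.
    """
    n = len(adj)
    edges = [(i, j) for i in range(n) for j in range(n) if adj[i][j] != 0]
    grad = [[0.0] * len(edges) for _ in range(n)]
    for e, (i, j) in enumerate(edges):
        grad[i][e] = 1.0
        grad[j][e] = -1.0
    return grad


def edges_to_nodes(hidden, inv_grad):
    """Map per-edge features (E, R) to per-node features (N, R) via inv_grad (E, N)."""
    num_edges, num_nodes = len(inv_grad), len(inv_grad[0])
    rnn_units = len(hidden[0])
    return [[sum(hidden[e][r] * inv_grad[e][n] for e in range(num_edges))
             for r in range(rnn_units)]
            for n in range(num_nodes)]


adj = [[0, 1, 0],
       [0, 0, 1],
       [1, 0, 0]]                      # directed 3-cycle: 0->1, 1->2, 2->0
grad = graph_grad(adj)                 # (N=3, E=3)
# Transpose to (E, N); the 0.5 weight on each incident node is an assumed scaling.
inv_grad = [[0.5 if grad[n][e] != 0 else 0.0 for n in range(len(grad))]
            for e in range(len(grad[0]))]
hidden = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # (E=3, rnn_units=2)
node_feats = edges_to_nodes(hidden, inv_grad)   # (N=3, rnn_units=2)
print(node_feats)  # [[3.0, 4.0], [2.0, 3.0], [4.0, 5.0]]
```

Under these assumptions each node receives the average of the GRU states of its two incident edges, which is how per-edge encodings become per-node latent inputs.
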
---
### Module details: sampling (utils.sample_standard_gaussian)
Step | Operation | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | Repeat to n_traj | mean/std: (1, batch, N·Z) | (n_traj, batch, N·Z)
2 | Reparameterized sampling | mu, sigma: (n_traj, batch, N·Z) | z0: (n_traj, batch, N·Z)

Here N·Z = num_nodes x latent_dim.

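
The two steps above amount to the usual reparameterization trick, z0 = mu + sigma · eps with eps drawn from a standard normal. A minimal stdlib sketch follows; it reuses the name `sample_standard_gaussian` for readability but is not the project's implementation, and `repeat_first_dim` is a hypothetical helper.

```python
import random


def repeat_first_dim(x, n_traj):
    """Repeat a (1, batch, N*Z) nested list along axis 0 to (n_traj, batch, N*Z)."""
    return [[row[:] for row in x[0]] for _ in range(n_traj)]


def sample_standard_gaussian(mu, sigma, rng):
    """Element-wise reparameterized draw: z = mu + sigma * eps, eps ~ N(0, 1)."""
    return [[[m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu_row, sg_row)]
             for mu_row, sg_row in zip(mu_traj, sg_traj)]
            for mu_traj, sg_traj in zip(mu, sigma)]


rng = random.Random(0)
mean = [[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]]   # (1, batch=2, N*Z=3)
std = [[[0.0, 0.0, 0.0], [0.1, 0.1, 0.1]]]    # first batch element is deterministic
n_traj = 4
mu = repeat_first_dim(mean, n_traj)            # (4, 2, 3)
sigma = repeat_first_dim(std, n_traj)          # (4, 2, 3)
z0 = sample_standard_gaussian(mu, sigma, rng)
print(len(z0), len(z0[0]), len(z0[0][0]))      # 4 2 3
```

Where sigma is zero, every trajectory sample collapses onto the mean; elsewhere the n_traj samples differ, which is what the Decoder later averages over.
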
---
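### Module details: DiffeqSolver (including ODEFunc calls)
Step | Operation | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | Merge the sample dimension into the batch | first_point: (n_traj, batch, N·Z) | (n_traj x batch, N·Z)
2 | ODE integration | t: (horizon,), y0: (n_traj x batch, N·Z) | pred_y: (horizon, n_traj x batch, N·Z)
3 | Restore dimensions | (horizon, n_traj x batch, N·Z) | (horizon, n_traj, batch, N·Z)
4 | Record cost | odefunc.nfe, elapsed_time | fe: (nfe:int, time:float)

By default (`filter_type="default"`), ODEFunc models a diffusion process: random-walk support matrices combined with gated multi-order graph convolutions.
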
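
The four solver steps can be illustrated with a toy stand-in. In this sketch a fixed-step explicit Euler loop replaces whatever solver `ode_method` actually selects, and `CountingODEFunc` and `solve` are hypothetical names; only the dimension handling and the `(nfe, time)` bookkeeping mirror the table.

```python
import time


class CountingODEFunc:
    """Toy right-hand side dy/dt = -y that counts its evaluations in `nfe`."""

    def __init__(self):
        self.nfe = 0

    def __call__(self, t, y):
        self.nfe += 1
        return [[-v for v in row] for row in y]


def solve(odefunc, first_point, t):
    """Integrate first_point (n_traj, batch, D) over times t; return (sol_ys, fe)."""
    n_traj, batch = len(first_point), len(first_point[0])
    # Step 1: merge sample and batch dimensions -> (n_traj * batch, D).
    y = [row[:] for traj in first_point for row in traj]
    start = time.time()
    pred_y = [y]  # the state at t[0] is the initial value itself
    # Step 2: explicit Euler between consecutive time points (stand-in solver).
    for t0, t1 in zip(t[:-1], t[1:]):
        dy = odefunc(t0, y)
        y = [[v + (t1 - t0) * d for v, d in zip(row, drow)]
             for row, drow in zip(y, dy)]
        pred_y.append(y)
    elapsed = time.time() - start
    # Step 3: restore (horizon, n_traj, batch, D).
    sol_ys = [[step[s * batch:(s + 1) * batch] for s in range(n_traj)]
              for step in pred_y]
    # Step 4: report the cost as (number of function evaluations, seconds).
    return sol_ys, (odefunc.nfe, elapsed)


first_point = [[[1.0, 2.0], [3.0, 4.0]],
               [[5.0, 6.0], [7.0, 8.0]]]      # (n_traj=2, batch=2, D=2)
t = [0.0, 0.5, 1.0]                            # horizon = 3
sol_ys, (nfe, elapsed) = solve(CountingODEFunc(), first_point, t)
print(len(sol_ys), len(sol_ys[0]), len(sol_ys[0][0]), nfe)  # 3 2 2 2
```

An adaptive solver would typically call the ODE function many more times than there are output time points, which is why `nfe` is reported separately from `horizon`.
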
---
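### Module details: ODEFunc (default diffusion process)
Step | Operation | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | Reshape | y: (B, N·Z) → (B, N, Z) | (B, N, Z)
2 | Multi-order graph convolution _gconv | (B, N, Z) | (B, N, Z'), with Z' set as needed (usually kept equal to Z)
3 | Gate θ | _gconv(..., output=latent_dim) → Sigmoid | θ: (B, N·Z)
4 | Vector field ode_func_net | stacked _gconv layers + activation | f(y): (B, N·Z)
5 | Right-hand side | -θ ⊙ f(y) | dy/dt: (B, N·Z)

Notes:
- The support matrices come from `utils.calculate_random_walk_matrix(adj)` (forward and reverse) and are used to build the multi-order channels through a sparse Chebyshev recurrence.
- If `filter_type="unkP"`, a fully connected network built by `create_net` computes the gradient pointwise in the node domain.
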
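
The default diffusion path can be sketched on a tiny graph with one latent channel per node. Everything learned is replaced by a fixed stand-in here: the channels are summed without weights instead of passing through `_gconv`, and `tanh` replaces `ode_func_net`. Only the row-normalized random-walk matrix, the Chebyshev recurrence, and the form dy/dt = -θ ⊙ f(y) follow the description above; all function names are hypothetical.

```python
import math


def random_walk_matrix(adj):
    """Row-normalize the adjacency: P = D^-1 A (zero-degree rows stay zero)."""
    out = []
    for row in adj:
        d = sum(row)
        out.append([a / d if d else 0.0 for a in row])
    return out


def matvec(m, x):
    """Dense matrix-vector product."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def chebyshev_channels(support, x, order):
    """Return [x, S x, 2 S (S x) - x, ...] up to `order`."""
    channels = [x]
    if order >= 1:
        channels.append(matvec(support, x))
    for _ in range(2, order + 1):
        prev2, prev1 = channels[-2], channels[-1]
        channels.append([2.0 * a - b
                         for a, b in zip(matvec(support, prev1), prev2)])
    return channels


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def rhs(adj, y, order=2):
    """dy/dt = -theta * f(y); unweighted channel sums stand in for learned layers."""
    channels = chebyshev_channels(random_walk_matrix(adj), y, order)
    mixed = [sum(vals) for vals in zip(*channels)]
    theta = [sigmoid(v) for v in mixed]      # gate, values in (0, 1)
    field = [math.tanh(v) for v in mixed]    # stand-in for ode_func_net
    return [-g * f for g, f in zip(theta, field)]


adj = [[0, 1, 1],
       [1, 0, 0],
       [1, 0, 0]]
y = [1.0, 2.0, 3.0]                          # one latent value per node
channels = chebyshev_channels(random_walk_matrix(adj), y, order=2)
dy = rhs(adj, y)
print(channels[2])   # [1.0, 3.0, 2.0]
```

Because the gate lies in (0, 1) and the sign is negative, the state is pushed against the direction of the field at a node-dependent rate, which is the damping behaviour expected of a diffusion process.
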
---
### Module details: Decoder
Step | Operation | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | Reshape to the node domain | (T, S, B, N·Z) → (T, S, B, N, Z) | (T, S, B, N, Z)
2 | Node→edge mapping | multiply by graph_grad (N, E) | (T, S, B, Z, E)
3 | Mean over trajectories and channels | mean over the S and Z dimensions | (T, B, E)
4 | Flatten to the output dimension | accounts for output_dim (usually 1) | (T, B, E x output_dim)

Notation: T=horizon, S=n_traj_samples, N=num_nodes, E=num_edges, Z=latent_dim, B=batch.

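
Steps 2 and 3 of the Decoder can be sketched for a single time step and batch element. The incidence matrix below is a hand-written example under the same signed-incidence assumption (+1 at the source node, -1 at the target node), and `decode_step` is a hypothetical helper, not the project's code.

```python
def decode_step(latent, grad):
    """latent: (S, N, Z) for one time step and batch element; grad: (N, E).

    Returns one value per edge, shape (E,).
    """
    n_traj, num_nodes, latent_dim = len(latent), len(latent[0]), len(latent[0][0])
    num_edges = len(grad[0])
    out = []
    for e in range(num_edges):
        # Node -> edge: project each latent channel of each sample onto edge e.
        vals = [sum(latent[s][n][z] * grad[n][e] for n in range(num_nodes))
                for s in range(n_traj) for z in range(latent_dim)]
        # Average over the S trajectory samples and the Z latent channels.
        out.append(sum(vals) / len(vals))
    return out


# Assumed signed incidence matrix for the directed 3-cycle 0->1, 1->2, 2->0.
grad = [[1.0, 0.0, -1.0],
        [-1.0, 1.0, 0.0],
        [0.0, -1.0, 1.0]]
latent = [[[2.0, 4.0], [1.0, 1.0], [0.0, 0.0]],   # sample 0: (N=3, Z=2)
          [[4.0, 6.0], [1.0, 1.0], [0.0, 0.0]]]   # sample 1
edge_values = decode_step(latent, grad)
print(edge_values)  # [3.0, 1.0, -4.0]
```

Under this assumption each edge value is the averaged difference between the latent states of its two endpoint nodes, so the Decoder has no learned parameters in this sketch.
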
---
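### Notes and conventions
- Internally, the time-series input is edge-flattened: `(seq_len, batch, num_edges x input_dim)`.
- Graph operators: `utils.graph_grad(adj)` has shape `(N, E)`; `utils.calculate_random_walk_matrix(adj)` produces the sparse random-walk matrix used for graph convolution.
- Key hyperparameters (passed in via the configuration): `latent_dim`, `rnn_units`, `gcn_step`, `n_traj_samples`, `ode_method`, `horizon`, `input_dim`, `output_dim`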
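
For reference, the hyperparameters listed above could be grouped as follows. `STDENConfig` is a hypothetical container, not a class from the project, and every value is a placeholder rather than a recommended setting.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class STDENConfig:
    """Hypothetical container for the hyperparameters listed above."""
    latent_dim: int
    rnn_units: int
    gcn_step: int
    n_traj_samples: int
    ode_method: str
    horizon: int
    input_dim: int
    output_dim: int

    def input_width(self, num_edges: int) -> int:
        """Last-axis size of the edge-flattened input: num_edges x input_dim."""
        return num_edges * self.input_dim

    def latent_width(self, num_nodes: int) -> int:
        """Last-axis size of the flattened latent state: num_nodes x latent_dim."""
        return num_nodes * self.latent_dim


# Placeholder values for illustration only.
cfg = STDENConfig(latent_dim=4, rnn_units=64, gcn_step=2, n_traj_samples=3,
                  ode_method="dopri5", horizon=12, input_dim=2, output_dim=1)
print(cfg.input_width(num_edges=8), cfg.latent_width(num_nodes=5))  # 16 20
```
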