add STDEN

parent 66a23ffbbb
commit ab5811425d
@@ -160,4 +160,5 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

STDEN/
.STDEN/
.data/PEMS08/
@@ -0,0 +1 @@
Subproject commit e50a1ba6d70528b3e684c85f316aed05bb5085f2
@@ -0,0 +1,65 @@
basic:
  device: cuda:0
  dataset: PEMS08
  model: STDEN
  mode: train
  seed: 2025

data:
  dataset_dir: data/PEMS08
  val_batch_size: 32
  graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
  num_nodes: 170
  batch_size: 32
  input_dim: 1
  lag: 24
  horizon: 24
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 24
  days_per_week: 7

model:
  l1_decay: 0
  seq_len: 12
  horizon: 12
  input_dim: 1
  output_dim: 1
  latent_dim: 4
  n_traj_samples: 3
  ode_method: dopri5
  odeint_atol: 0.00001
  odeint_rtol: 0.00001
  rnn_units: 64
  num_rnn_layers: 1
  gcn_step: 2
  filter_type: default  # one of: unkP, IncP, default
  recg_type: gru
  save_latent: false
  nfe: false

train:
  loss: mae
  batch_size: 64
  epochs: 100
  lr_init: 0.003
  mape_thresh: 0.001
  mae_thresh: None
  debug: False
  output_dim: 1
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True
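For orientation, a minimal sketch of how a config like the one above might be consumed (assuming PyYAML; the path `configs/PEMS08.yaml` is hypothetical, the repository's actual layout may differ):

```python
import yaml

# hypothetical path; adjust to where the repository keeps its configs
with open("configs/PEMS08.yaml") as f:
    config = yaml.safe_load(f)

print(config["basic"]["model"])       # 'STDEN'
print(config["model"]["latent_dim"])  # 4
# caution: YAML parses True/False as booleans, but `mae_thresh: None`
# becomes the string 'None', not Python None (only null/~ map to None)
print(repr(config["train"]["mae_thresh"]))  # 'None'
```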
@@ -0,0 +1,44 @@
---
log_base_dir: logs/BJ_GM
log_level: INFO

data:
  batch_size: 32
  dataset_dir: data/BJ_GM
  val_batch_size: 32
  graph_pkl_filename: data/sensor_graph/adj_GM.npy

model:
  l1_decay: 0
  seq_len: 12
  horizon: 12
  input_dim: 1
  output_dim: 1
  latent_dim: 4
  n_traj_samples: 3
  ode_method: dopri5
  odeint_atol: 0.00001
  odeint_rtol: 0.00001
  rnn_units: 64
  num_rnn_layers: 1
  gcn_step: 2
  filter_type: default  # one of: unkP, IncP, default
  recg_type: gru
  save_latent: false
  nfe: false

train:
  base_lr: 0.01
  dropout: 0
  load: 0
  epoch: 0
  epochs: 100
  epsilon: 1.0e-3
  lr_decay_ratio: 0.1
  max_grad_norm: 5
  min_learning_rate: 2.0e-06
  optimizer: adam
  patience: 20
  steps: [20, 30, 40, 50]
  test_every_n_epochs: 5
@@ -0,0 +1,44 @@
---
log_base_dir: logs/BJ_RM
log_level: INFO

data:
  batch_size: 32
  dataset_dir: data/BJ_RM
  val_batch_size: 32
  graph_pkl_filename: data/sensor_graph/adj_RM.npy

model:
  l1_decay: 0
  seq_len: 12
  horizon: 12
  input_dim: 1
  output_dim: 1
  latent_dim: 4
  n_traj_samples: 3
  ode_method: dopri5
  odeint_atol: 0.00001
  odeint_rtol: 0.00001
  rnn_units: 64  # for the recognition network
  num_rnn_layers: 1
  gcn_step: 2
  filter_type: default  # one of: unkP, IncP, default
  recg_type: gru
  save_latent: false
  nfe: false

train:
  base_lr: 0.01
  dropout: 0
  load: 0  # 0 = do not load a checkpoint
  epoch: 0
  epochs: 100
  epsilon: 1.0e-3
  lr_decay_ratio: 0.1
  max_grad_norm: 5
  min_learning_rate: 2.0e-06
  optimizer: adam
  patience: 20
  steps: [20, 30, 40, 50]
  test_every_n_epochs: 5
@@ -0,0 +1,44 @@
---
log_base_dir: logs/BJ_XZ
log_level: INFO

data:
  batch_size: 32
  dataset_dir: data/BJ_XZ
  val_batch_size: 32
  graph_pkl_filename: data/sensor_graph/adj_XZ.npy

model:
  l1_decay: 0
  seq_len: 12
  horizon: 12
  input_dim: 1
  output_dim: 1
  latent_dim: 4
  n_traj_samples: 3
  ode_method: dopri5
  odeint_atol: 0.00001
  odeint_rtol: 0.00001
  rnn_units: 64
  num_rnn_layers: 1
  gcn_step: 2
  filter_type: default  # one of: unkP, IncP, default
  recg_type: gru
  save_latent: false
  nfe: false

train:
  base_lr: 0.01
  dropout: 0
  load: 0  # 0 = do not load a checkpoint
  epoch: 0
  epochs: 100
  epsilon: 1.0e-3
  lr_decay_ratio: 0.1
  max_grad_norm: 5
  min_learning_rate: 2.0e-06
  optimizer: adam
  patience: 20
  steps: [20, 30, 40, 50]
  test_every_n_epochs: 5
@@ -0,0 +1,88 @@
### STDEN modules and execution flow (indentation-level table)

Module | Class/Function | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | STDENModel.forward | inputs: (seq_len, batch_size, num_edges x input_dim) | outputs: (horizon, batch_size, num_edges x output_dim); fe: (nfe:int, time:float)
1.1 | Encoder_z0_RNN.forward | (seq_len, batch_size, num_edges x input_dim) | mean: (1, batch_size, num_nodes x latent_dim); std: (1, batch_size, num_nodes x latent_dim)
1.1.1 | utils.sample_standard_gaussian | mu: (n_traj, batch, num_nodes x latent_dim); sigma: same shape | z0: (n_traj, batch, num_nodes x latent_dim)
1.2 | DiffeqSolver.forward | first_point: (n_traj, batch, num_nodes x latent_dim); t: (horizon,) | sol_ys: (horizon, n_traj, batch, num_nodes x latent_dim); fe: (nfe:int, time:float)
1.2.1 | ODEFunc.forward | t_local: scalar / 1-D; y: (B, num_nodes x latent_dim) | dy/dt: (B, num_nodes x latent_dim)
1.3 | Decoder.forward | (horizon, n_traj, batch, num_nodes x latent_dim) | (horizon, batch, num_edges x output_dim)

---

### Detail — Encoder_z0_RNN

Step | Operation | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | Reshape to edge batch | (seq_len, batch, num_edges x input_dim) | (seq_len, batch, num_edges, input_dim)
2 | Fold edges into the batch | (seq_len, batch, num_edges, input_dim) | (seq_len, batch x num_edges, input_dim)
3 | GRU sequence encoding | as above | (seq_len, batch x num_edges, rnn_units)
4 | Take the last time step | as above | (batch x num_edges, rnn_units)
5 | Restore the edge dimension | (batch x num_edges, rnn_units) | (batch, num_edges, rnn_units)
6 | Transpose + edge-to-node mapping | (batch, num_edges, rnn_units) via inv_grad | (batch, num_nodes, rnn_units)
7 | Fully connected map to 2x latent | (batch, num_nodes, rnn_units) | (batch, num_nodes, 2 x latent_dim)
8 | Split into mean/std | as above | mean/std: (batch, num_nodes, latent_dim)
9 | Flatten and add a time axis | (batch, num_nodes, latent_dim) | (1, batch, num_nodes x latent_dim)

Note: inv_grad is derived from `utils.graph_grad(adj).T` with rescaling; `hiddens_to_z0` is a two-layer MLP with Tanh followed by a linear map to 2 x latent_dim. A shape-level sketch of steps 5-6 follows.
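A minimal sketch of the edge-to-node mapping in steps 5-6, with toy sizes; `inv_grad` here is a random stand-in for the rescaled, transposed graph gradient:

```python
import torch

batch, num_edges, num_nodes, rnn_units = 2, 5, 4, 8

h = torch.randn(batch, num_edges, rnn_units)  # step 5: per-edge GRU features
inv_grad = torch.randn(num_edges, num_nodes)  # stand-in for graph_grad(adj).T with nonzeros set to 0.5

# step 6: transpose, map edges -> nodes, transpose back
out = torch.matmul(h.transpose(-2, -1), inv_grad).transpose(-2, -1)
print(out.shape)  # torch.Size([2, 4, 8]) == (batch, num_nodes, rnn_units)
```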
---

### Detail — sampling (utils.sample_standard_gaussian)

Step | Operation | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | Repeat to n_traj | mean/std: (1, batch, N·Z), repeated | (n_traj, batch, N·Z)
2 | Reparameterized sampling | mu, sigma | z0: (n_traj, batch, N·Z)

Here N·Z = num_nodes x latent_dim. A minimal sketch of the reparameterization follows.
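Step 2 is the standard reparameterization z0 = mu + sigma * eps with eps ~ N(0, I); a minimal sketch with toy sizes:

```python
import torch

n_traj, batch, nz = 3, 2, 8
mu = torch.zeros(1, batch, nz).repeat(n_traj, 1, 1)    # step 1: repeat to n_traj
sigma = torch.ones(1, batch, nz).repeat(n_traj, 1, 1)

eps = torch.randn_like(mu)  # eps ~ N(0, I)
z0 = mu + sigma * eps       # step 2: reparameterized sample
print(z0.shape)  # torch.Size([3, 2, 8])
```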
---

### Detail — DiffeqSolver (including the ODEFunc call)

Step | Operation | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | Merge the sample dimension | first_point: (n_traj, batch, N·Z) | (n_traj x batch, N·Z)
2 | ODE integration | t: (horizon,), y0 | pred_y: (horizon, n_traj x batch, N·Z)
3 | Restore the dimensions | as above | (horizon, n_traj, batch, N·Z)
4 | Cost accounting | odefunc.nfe, elapsed_time | fe: (nfe:int, time:float)

The default ODEFunc (filter_type="default") is a diffusion process: random-walk supports plus multi-order graph-convolution gating. A shape-only sketch of the merge/unmerge follows.
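A shape-only sketch of steps 1 and 3, the merge/unmerge around the integrator (the `odeint` call of step 2 is replaced by a stand-in trajectory; sizes are toy values):

```python
import torch

n_traj, batch, nz, horizon = 3, 2, 8, 12
first_point = torch.randn(n_traj, batch, nz)

y0 = first_point.reshape(n_traj * batch, nz)         # step 1: merge sample and batch dims
pred = y0.unsqueeze(0).expand(horizon, -1, -1)       # stand-in for the odeint result (step 2)
pred = pred.reshape(horizon, n_traj, batch, nz)      # step 3: restore the dimensions
print(pred.shape)  # torch.Size([12, 3, 2, 8])
```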
---

### Detail — ODEFunc (default diffusion process)

Step | Operation | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | Shape handling | y: (B, N·Z) → (B, N, Z) | (B, N, Z)
2 | Multi-order graph convolution _gconv | (B, N, Z) | (B, N, Z'), with Z' as configured (usually Z)
3 | Gate θ | _gconv(..., output=latent_dim) → Sigmoid | θ: (B, N·Z)
4 | Field via ode_func_net | stacked _gconv + activation | f(y): (B, N·Z)
5 | Right-hand side | - θ ⊙ f(y) | dy/dt: (B, N·Z)

Notes:
- The support matrices come from `utils.calculate_random_walk_matrix(adj)` (forward/backward) and feed the multi-order channels of a sparse Chebyshev-style recurrence.
- With `filter_type="unkP"`, a fully connected network from `create_net` instead computes the gradient pointwise in the node domain.

A sketch of the gated right-hand side follows.
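Steps 3-5 in isolation: the right-hand side is dy/dt = -θ ⊙ f(y), with θ produced by a sigmoid gate. A sketch with stand-in networks (the real model uses `_gconv` for both the gate and the field):

```python
import torch
import torch.nn as nn

B, N, Z = 2, 4, 3
y = torch.randn(B, N * Z)

gate_net = nn.Linear(N * Z, N * Z)   # stand-in for _gconv(..., latent_dim)
field_net = nn.Linear(N * Z, N * Z)  # stand-in for ode_func_net

theta = torch.sigmoid(gate_net(y))   # step 3: gate in (0, 1)
f_y = torch.tanh(field_net(y))       # step 4: potential field f(y)
dydt = -theta * f_y                  # step 5: gated diffusion gradient
print(dydt.shape)  # torch.Size([2, 12])
```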
---

### Detail — Decoder

Step | Operation | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | Reshape to the node domain | (T, S, B, N·Z) → (T, S, B, N, Z) | (T, S, B, N, Z)
2 | Node-to-edge mapping | multiply by graph_grad (N, E) | (T, S, B, Z, E)
3 | Mean over trajectories and channels | average over the S and Z axes | (T, B, E)
4 | Flatten to the output dimension | accounts for output_dim (usually 1) | (T, B, E x output_dim)

Notation: T=horizon, S=n_traj_samples, N=num_nodes, E=num_edges, Z=latent_dim, B=batch. A pure-shape sketch follows.
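Steps 1-4 as a pure-shape sketch; `graph_grad` is a random stand-in for the (N, E) operator from `utils.graph_grad(adj)`:

```python
import torch

T, S, B, N, Z, E = 12, 3, 2, 4, 3, 5
sol = torch.randn(T, S, B, N * Z)
graph_grad = torch.randn(N, E)  # stand-in for utils.graph_grad(adj)

x = sol.reshape(T, S, B, N, Z).transpose(-2, -1)  # step 1: (T, S, B, Z, N)
x = torch.matmul(x, graph_grad)                   # step 2: nodes -> edges, (T, S, B, Z, E)
x = x.mean(dim=3).mean(dim=1)                     # step 3: average over Z and S -> (T, B, E)
out = x.reshape(T, B, -1)                         # step 4: (T, B, E * output_dim), output_dim = 1
print(out.shape)  # torch.Size([12, 2, 5])
```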
---

### Notes and conventions
- Internally the model works on edge-flattened sequential input: `(seq_len, batch, num_edges x input_dim)`.
- Graph operators: `utils.graph_grad(adj)` has shape `(N, E)`; `utils.calculate_random_walk_matrix(adj)` builds the sparse random-walk matrices used by the graph convolution.
- Key hyperparameters (passed in from the config): `latent_dim`, `rnn_units`, `gcn_step`, `n_traj_samples`, `ode_method`, `horizon`, `input_dim`, `output_dim`.
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
import time

from torchdiffeq import odeint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DiffeqSolver(nn.Module):
    def __init__(self, odefunc, method, latent_dim,
                 odeint_rtol=1e-4, odeint_atol=1e-5):
        nn.Module.__init__(self)

        self.ode_method = method
        self.odefunc = odefunc
        self.latent_dim = latent_dim

        self.rtol = odeint_rtol
        self.atol = odeint_atol

    def forward(self, first_point, time_steps_to_pred):
        """
        Decode the trajectory through the ODE solver.

        :param time_steps_to_pred: horizon
        :param first_point: (n_traj_samples, batch_size, num_nodes * latent_dim)
        :return: pred_y: shape (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
        """
        n_traj_samples, batch_size = first_point.size()[0], first_point.size()[1]
        # merge the sample and batch dimensions to reduce solver overhead
        first_point = first_point.reshape(n_traj_samples * batch_size, -1)

        # pred_y shape: (horizon, n_traj_samples * batch_size, num_nodes * latent_dim)
        start_time = time.time()
        self.odefunc.nfe = 0  # reset the function-evaluation counter
        pred_y = odeint(self.odefunc,
                        first_point,
                        time_steps_to_pred,
                        rtol=self.rtol,
                        atol=self.atol,
                        method=self.ode_method)
        time_fe = time.time() - start_time

        # pred_y shape: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
        pred_y = pred_y.reshape(pred_y.size()[0], n_traj_samples, batch_size, -1)
        # assert(pred_y.size()[1] == n_traj_samples)
        # assert(pred_y.size()[2] == batch_size)

        return pred_y, (self.odefunc.nfe, time_fe)
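A minimal usage sketch for the solver above (requires `torchdiffeq`; the toy dynamics and sizes are hypothetical, only to exercise the interface):

```python
import torch
import torch.nn as nn

class ToyODEFunc(nn.Module):
    """Hypothetical linear dynamics dy/dt = A y, only to exercise the solver."""
    def __init__(self, dim):
        super().__init__()
        self.A = nn.Linear(dim, dim, bias=False)
        self.nfe = 0  # DiffeqSolver resets and reports this counter

    def forward(self, t, y):
        self.nfe += 1
        return self.A(y)

solver = DiffeqSolver(ToyODEFunc(8), method='dopri5', latent_dim=8)
z0 = torch.randn(3, 4, 8)             # (n_traj_samples, batch, latent)
t = torch.linspace(0., 1., steps=12)  # 12 prediction steps
pred, (nfe, seconds) = solver(z0, t)
print(pred.shape, nfe)  # torch.Size([12, 3, 4, 8]) and the evaluation count
```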
@@ -0,0 +1,165 @@
import numpy as np
import torch
import torch.nn as nn

from models.STDEN import utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LayerParams:
    def __init__(self, rnn_network: nn.Module, layer_type: str):
        self._rnn_network = rnn_network
        self._params_dict = {}
        self._biases_dict = {}
        self._type = layer_type

    def get_weights(self, shape):
        if shape not in self._params_dict:
            nn_param = nn.Parameter(torch.empty(*shape, device=device))
            nn.init.xavier_normal_(nn_param)
            self._params_dict[shape] = nn_param
            self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
                                                 nn_param)
        return self._params_dict[shape]

    def get_biases(self, length, bias_start=0.0):
        if length not in self._biases_dict:
            biases = nn.Parameter(torch.empty(length, device=device))
            nn.init.constant_(biases, bias_start)
            self._biases_dict[length] = biases
            self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
                                                 biases)

        return self._biases_dict[length]


class ODEFunc(nn.Module):
    def __init__(self, num_units, latent_dim, adj_mx, gcn_step, num_nodes,
                 gen_layers=1, nonlinearity='tanh', filter_type="default"):
        """
        :param num_units: dimensionality of the hidden layers
        :param latent_dim: dimensionality used for the ODE (input and output); analog of a continuous latent state
        :param adj_mx: adjacency matrix
        :param gcn_step: number of graph-convolution (diffusion) steps
        :param num_nodes: number of nodes in the graph
        :param gen_layers: hidden layers in each ODE func
        :param nonlinearity: 'tanh' or 'relu'
        :param filter_type: one of 'unkP', 'IncP', 'default'
        """
        super(ODEFunc, self).__init__()
        self._activation = torch.tanh if nonlinearity == 'tanh' else torch.relu

        self._num_nodes = num_nodes
        self._num_units = num_units  # hidden dimension
        self._latent_dim = latent_dim
        self._gen_layers = gen_layers
        self.nfe = 0  # number of function evaluations, reset by the solver

        self._filter_type = filter_type
        if self._filter_type == "unkP":
            ode_func_net = utils.create_net(latent_dim, latent_dim, n_units=num_units)
            utils.init_network_weights(ode_func_net)
            self.gradient_net = ode_func_net
        else:
            self._gcn_step = gcn_step
            self._gconv_params = LayerParams(self, 'gconv')
            self._supports = []
            supports = []
            # forward and backward random-walk transition matrices
            supports.append(utils.calculate_random_walk_matrix(adj_mx).T)
            supports.append(utils.calculate_random_walk_matrix(adj_mx.T).T)

            for support in supports:
                self._supports.append(self._build_sparse_matrix(support))

    @staticmethod
    def _build_sparse_matrix(L):
        L = L.tocoo()
        indices = np.column_stack((L.row, L.col))
        # this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L)
        indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
        L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device)
        return L

    def forward(self, t_local, y, backwards=False):
        """
        Perform one step in solving the ODE. Given the current data point y and the
        current time point t_local, return the gradient dy/dt at this time point.

        t_local: current time point
        y: value at the current time point, shape (B, num_nodes * latent_dim)

        :return: a 2-D tensor with shape (B, num_nodes * latent_dim)
        """
        self.nfe += 1
        grad = self.get_ode_gradient_nn(t_local, y)
        if backwards:
            grad = -grad
        return grad

    def get_ode_gradient_nn(self, t_local, inputs):
        if self._filter_type == "unkP":  # unknown potential: pointwise fully connected net
            grad = self._fc(inputs)
        elif self._filter_type == "IncP":  # ungated field
            grad = - self.ode_func_net(inputs)
        else:  # default is a diffusion process with a learned gate
            # theta shape: (B, num_nodes * latent_dim)
            theta = torch.sigmoid(self._gconv(inputs, self._latent_dim, bias_start=1.0))
            grad = - theta * self.ode_func_net(inputs)
        return grad

    def ode_func_net(self, inputs):
        c = inputs
        for i in range(self._gen_layers):
            c = self._gconv(c, self._num_units)
            c = self._activation(c)
        c = self._gconv(c, self._latent_dim)
        c = self._activation(c)
        return c

    def _fc(self, inputs):
        batch_size = inputs.size()[0]
        grad = self.gradient_net(inputs.view(batch_size * self._num_nodes, self._latent_dim))
        return grad.reshape(batch_size, self._num_nodes * self._latent_dim)

    @staticmethod
    def _concat(x, x_):
        x_ = x_.unsqueeze(0)
        return torch.cat([x, x_], dim=0)

    def _gconv(self, inputs, output_size, bias_start=0.0):
        # Reshape input to (batch_size, num_nodes, input_dim)
        batch_size = inputs.shape[0]
        inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))
        input_size = inputs.size(2)

        x = inputs
        x0 = x.permute(1, 2, 0)  # (num_nodes, input_size, batch_size)
        x0 = torch.reshape(x0, shape=[self._num_nodes, input_size * batch_size])
        x = torch.unsqueeze(x0, 0)

        if self._gcn_step == 0:
            pass
        else:
            # Chebyshev-style recurrence over each random-walk support;
            # keep x0 pristine so every support starts from the raw signal
            for support in self._supports:
                x1 = torch.sparse.mm(support, x0)
                x = self._concat(x, x1)

                x_prev, x_curr = x0, x1
                for k in range(2, self._gcn_step + 1):
                    x2 = 2 * torch.sparse.mm(support, x_curr) - x_prev
                    x = self._concat(x, x2)
                    x_prev, x_curr = x_curr, x2

        num_matrices = len(self._supports) * self._gcn_step + 1  # adds one for x itself
        x = torch.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size])
        x = x.permute(3, 1, 2, 0)  # (batch_size, num_nodes, input_size, num_matrices)
        x = torch.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices])

        weights = self._gconv_params.get_weights((input_size * num_matrices, output_size))
        x = torch.matmul(x, weights)  # (batch_size * num_nodes, output_size)

        biases = self._gconv_params.get_biases(output_size, bias_start)
        x += biases
        # Reshape back to 2-D: (batch_size, num_nodes, output_size) -> (batch_size, num_nodes * output_size)
        return torch.reshape(x, [batch_size, self._num_nodes * output_size])
@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
from torch.nn.modules.rnn import GRU
from models.STDEN.ode_func import ODEFunc
from models.STDEN.diffeq_solver import DiffeqSolver
from models.STDEN import utils
from data.graph_loader import load_graph


class EncoderAttrs:
    """Shared encoder attribute configuration."""
    def __init__(self, config, adj_mx):
        self.adj_mx = adj_mx
        self.num_nodes = adj_mx.shape[0]
        self.num_edges = (adj_mx > 0.).sum()
        self.gcn_step = int(config.get('gcn_step', 2))
        self.filter_type = config.get('filter_type', 'default')
        self.num_rnn_layers = int(config.get('num_rnn_layers', 1))
        self.rnn_units = int(config.get('rnn_units'))
        self.latent_dim = int(config.get('latent_dim', 4))


class STDENModel(nn.Module, EncoderAttrs):
    """STDEN main model: spatio-temporal differential equation network."""
    def __init__(self, config):
        nn.Module.__init__(self)
        adj_mx = load_graph(config)
        EncoderAttrs.__init__(self, config['model'], adj_mx)

        # recognition network
        self.encoder_z0 = Encoder_z0_RNN(config['model'], adj_mx)

        model_kwargs = config['model']
        # ODE solver configuration
        self.n_traj_samples = int(model_kwargs.get('n_traj_samples', 1))
        self.ode_method = model_kwargs.get('ode_method', 'dopri5')
        self.atol = float(model_kwargs.get('odeint_atol', 1e-4))
        self.rtol = float(model_kwargs.get('odeint_rtol', 1e-3))
        self.num_gen_layer = int(model_kwargs.get('gen_layers', 1))
        self.ode_gen_dim = int(model_kwargs.get('gen_dim', 64))

        # build the ODE function and solver
        odefunc = ODEFunc(
            self.ode_gen_dim, self.latent_dim, adj_mx,
            self.gcn_step, self.num_nodes, filter_type=self.filter_type
        )

        self.diffeq_solver = DiffeqSolver(
            odefunc, self.ode_method, self.latent_dim,
            odeint_rtol=self.rtol, odeint_atol=self.atol
        )

        # latent-feature saving
        self.save_latent = bool(model_kwargs.get('save_latent', False))
        self.latent_feat = None

        # decoder
        self.horizon = int(model_kwargs.get('horizon', 1))
        self.out_feat = int(model_kwargs.get('output_dim', 1))
        self.decoder = Decoder(
            self.out_feat, adj_mx, self.num_nodes, self.num_edges
        )

    def forward(self, inputs, labels=None, batches_seen=None):
        """
        seq2seq forward pass.
        :param inputs: (batch_size, seq_len, num_edges, input_dim)
        :param labels: (horizon, batch_size, num_edges * output_dim)
        :param batches_seen: number of batches seen so far
        :return: outputs: (batch_size, horizon, num_edges, output_dim); fe: (nfe, time)
        """
        # encode the initial latent state; the encoder expects (seq_len, batch, num_edges * input_dim)
        B, T, N, C = inputs.shape
        inputs = inputs.permute(1, 0, 2, 3).reshape(T, B, N * C)
        first_point_mu, first_point_std = self.encoder_z0(inputs)

        # sample trajectories
        means_z0 = first_point_mu.repeat(self.n_traj_samples, 1, 1)
        sigma_z0 = first_point_std.repeat(self.n_traj_samples, 1, 1)
        first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)

        # prediction time steps, normalized to [0, 1)
        time_steps_to_predict = torch.arange(start=0, end=self.horizon, step=1).float()
        time_steps_to_predict = time_steps_to_predict / len(time_steps_to_predict)
        time_steps_to_predict = time_steps_to_predict.to(first_point_enc.device)

        # solve the ODE
        sol_ys, fe = self.diffeq_solver(first_point_enc, time_steps_to_predict)

        if self.save_latent:
            self.latent_feat = torch.mean(sol_ys.detach(), axis=1)
        # decode: (horizon, batch, num_edges * output_dim)
        outputs = self.decoder(sol_ys)

        # restore the (batch, horizon, num_edges, output_dim) layout
        outputs = outputs.reshape(self.horizon, B, N, -1).permute(1, 0, 2, 3)

        return outputs, fe


class Encoder_z0_RNN(nn.Module, EncoderAttrs):
    """RNN encoder: encodes the input sequence into an initial latent state."""
    def __init__(self, config, adj_mx):
        nn.Module.__init__(self)
        EncoderAttrs.__init__(self, config, adj_mx)

        self.recg_type = config.get('recg_type', 'gru')
        self.input_dim = int(config.get('input_dim', 1))

        if self.recg_type == 'gru':
            self.gru_rnn = GRU(self.input_dim, self.rnn_units)
        else:
            raise NotImplementedError("Only the 'gru' recognition network is supported")

        # hidden-state-to-z0 mapping: edge-to-node operator derived from the graph gradient
        inv_grad = utils.graph_grad(adj_mx).transpose(-2, -1)
        inv_grad[inv_grad != 0.] = 0.5
        self.register_buffer('inv_grad', inv_grad)  # buffer, so it follows the module's device

        self.hiddens_to_z0 = nn.Sequential(
            nn.Linear(self.rnn_units, 50),
            nn.Tanh(),
            nn.Linear(50, self.latent_dim * 2)
        )
        utils.init_network_weights(self.hiddens_to_z0)

    def forward(self, inputs):
        """
        Encoder forward pass.
        :param inputs: (seq_len, batch_size, num_edges * input_dim)
        :return: mean, std: (1, batch_size, num_nodes * latent_dim)
        """
        seq_len, batch_size = inputs.size(0), inputs.size(1)

        # reshape and fold the edge dimension into the batch
        inputs = inputs.reshape(seq_len, batch_size, self.num_edges, self.input_dim)
        inputs = inputs.reshape(seq_len, batch_size * self.num_edges, self.input_dim)

        # GRU encoding; keep the last time step
        outputs, _ = self.gru_rnn(inputs)
        last_output = outputs[-1]

        # restore the edge dimension and map edges to nodes via inv_grad
        last_output = torch.reshape(last_output, (batch_size, self.num_edges, -1))
        last_output = last_output.transpose(-2, -1)
        last_output = torch.matmul(last_output, self.inv_grad).transpose(-2, -1)

        # produce the mean and standard deviation
        mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
        mean = mean.reshape(batch_size, -1)
        std = std.reshape(batch_size, -1).abs()

        return mean.unsqueeze(0), std.unsqueeze(0)


class Decoder(nn.Module):
    """Decoder: maps latent node states back to edge outputs."""
    def __init__(self, output_dim, adj_mx, num_nodes, num_edges):
        super(Decoder, self).__init__()
        self.num_nodes = num_nodes
        self.num_edges = num_edges
        self.register_buffer('graph_grad', utils.graph_grad(adj_mx))  # (num_nodes, num_edges)
        self.output_dim = output_dim

    def forward(self, inputs):
        """
        :param inputs: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
        :return: outputs: (horizon, batch_size, num_edges * output_dim)
        """
        horizon, n_traj_samples, batch_size = inputs.size()[:3]

        # reshape to the node domain: (horizon, n_traj, batch, latent_dim, num_nodes)
        inputs = inputs.reshape(horizon, n_traj_samples, batch_size, self.num_nodes, -1).transpose(-2, -1)
        latent_dim = inputs.size(-2)

        # graph-gradient transform: from nodes to edges
        outputs = torch.matmul(inputs, self.graph_grad)

        # average over sampled trajectories and latent channels
        outputs = outputs.reshape(horizon, n_traj_samples, batch_size, latent_dim, self.num_edges, self.output_dim)
        outputs = torch.mean(torch.mean(outputs, axis=3), axis=1)
        outputs = outputs.reshape(horizon, batch_size, -1)

        return outputs
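A smoke-test sketch for the encoder above (the module path `models.STDEN.stden_model` is assumed from the imports in this commit; the 3-node directed cycle is a toy graph):

```python
import numpy as np
import torch
from models.STDEN.stden_model import Encoder_z0_RNN  # assumed module path

# toy directed 3-cycle: 3 nodes, 3 edges
adj = np.array([[0., 1., 0.],
                [0., 0., 1.],
                [1., 0., 0.]])
cfg = {'recg_type': 'gru', 'input_dim': 1, 'rnn_units': 8, 'latent_dim': 4}

enc = Encoder_z0_RNN(cfg, adj)
x = torch.randn(12, 2, 3 * 1)  # (seq_len, batch, num_edges * input_dim)
mean, std = enc(x)
print(mean.shape, std.shape)  # (1, 2, 12) each: num_nodes * latent_dim = 3 * 4
```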
@@ -0,0 +1,234 @@
import logging
import numpy as np
import os
import time
import scipy.sparse as sp
import sys
import torch
import torch.nn as nn


class DataLoader(object):
    def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
        """
        :param xs: input samples
        :param ys: target samples
        :param batch_size: mini-batch size
        :param pad_with_last_sample: pad with the last sample so the number of samples is divisible by batch_size.
        """
        self.batch_size = batch_size
        self.current_ind = 0
        if pad_with_last_sample:
            num_padding = (batch_size - (len(xs) % batch_size)) % batch_size
            x_padding = np.repeat(xs[-1:], num_padding, axis=0)
            y_padding = np.repeat(ys[-1:], num_padding, axis=0)
            xs = np.concatenate([xs, x_padding], axis=0)
            ys = np.concatenate([ys, y_padding], axis=0)
        self.size = len(xs)
        self.num_batch = int(self.size // self.batch_size)
        if shuffle:
            permutation = np.random.permutation(self.size)
            xs, ys = xs[permutation], ys[permutation]
        self.xs = xs
        self.ys = ys

    def get_iterator(self):
        self.current_ind = 0

        def _wrapper():
            while self.current_ind < self.num_batch:
                start_ind = self.batch_size * self.current_ind
                end_ind = min(self.size, self.batch_size * (self.current_ind + 1))
                x_i = self.xs[start_ind: end_ind, ...]
                y_i = self.ys[start_ind: end_ind, ...]
                yield (x_i, y_i)
                self.current_ind += 1

        return _wrapper()


class StandardScaler:
    """
    Standardize the input.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        return (data * self.std) + self.mean


def calculate_random_walk_matrix(adj_mx):
    adj_mx = sp.coo_matrix(adj_mx)
    d = np.array(adj_mx.sum(1))
    d_inv = np.power(d, -1).flatten()
    d_inv[np.isinf(d_inv)] = 0.
    d_mat_inv = sp.diags(d_inv)
    random_walk_mx = d_mat_inv.dot(adj_mx).tocoo()
    return random_walk_mx


def config_logging(log_dir, log_filename='info.log', level=logging.INFO):
    # Add file handler and stdout handler
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # Create the log directory if necessary.
    try:
        os.makedirs(log_dir)
    except OSError:
        pass
    file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
    file_handler.setFormatter(formatter)
    file_handler.setLevel(level=level)
    # Add console handler.
    console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(level=level)
    logging.basicConfig(handlers=[file_handler, console_handler], level=level)


def get_logger(log_dir, name, log_filename='info.log', level=logging.INFO):
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Add file handler and stdout handler
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
    file_handler.setFormatter(formatter)
    # Add console handler.
    console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(console_formatter)
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    logger.info('Log directory: %s', log_dir)
    return logger


def get_log_dir(kwargs):
    log_dir = kwargs['train'].get('log_dir')
    if log_dir is None:
        batch_size = kwargs['data'].get('batch_size')

        filter_type = kwargs['model'].get('filter_type')
        gcn_step = kwargs['model'].get('gcn_step')
        horizon = kwargs['model'].get('horizon')
        latent_dim = kwargs['model'].get('latent_dim')
        n_traj_samples = kwargs['model'].get('n_traj_samples')
        ode_method = kwargs['model'].get('ode_method')

        seq_len = kwargs['model'].get('seq_len')
        rnn_units = kwargs['model'].get('rnn_units')
        recg_type = kwargs['model'].get('recg_type')

        if filter_type == 'unkP':
            filter_type_abbr = 'UP'
        elif filter_type == 'IncP':
            filter_type_abbr = 'NV'
        else:
            filter_type_abbr = 'DF'

        run_id = 'STDEN_%s-%d_%s-%d_L-%d_N-%d_M-%s_bs-%d_%d-%d_%s/' % (
            recg_type, rnn_units, filter_type_abbr, gcn_step, latent_dim, n_traj_samples, ode_method, batch_size,
            seq_len, horizon, time.strftime('%m%d%H%M%S'))
        base_dir = kwargs.get('log_base_dir')
        log_dir = os.path.join(base_dir, run_id)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir


def load_dataset(dataset_dir, batch_size, val_batch_size=None, **kwargs):
    if 'BJ' in dataset_dir:
        data = dict(np.load(os.path.join(dataset_dir, 'flow.npz')))  # convert read-only NpzFile to a writable dict
        for category in ['train', 'val', 'test']:
            data['x_' + category] = data['x_' + category]  # [..., :4] # ignore the time index
    else:
        data = {}
        for category in ['train', 'val', 'test']:
            cat_data = np.load(os.path.join(dataset_dir, category + '.npz'))
            data['x_' + category] = cat_data['x']
            data['y_' + category] = cat_data['y']
    scaler = StandardScaler(mean=data['x_train'].mean(), std=data['x_train'].std())
    # Data format
    for category in ['train', 'val', 'test']:
        data['x_' + category] = scaler.transform(data['x_' + category])
        data['y_' + category] = scaler.transform(data['y_' + category])
    data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True)
    data['val_loader'] = DataLoader(data['x_val'], data['y_val'], val_batch_size, shuffle=False)
    data['test_loader'] = DataLoader(data['x_test'], data['y_test'], val_batch_size, shuffle=False)
    data['scaler'] = scaler

    return data


def load_graph_data(pkl_filename):
    adj_mx = np.load(pkl_filename)
    return adj_mx


def graph_grad(adj_mx):
    """Fetch the graph gradient operator."""
    num_nodes = adj_mx.shape[0]

    num_edges = (adj_mx > 0.).sum()
    grad = torch.zeros(num_nodes, num_edges)
    e = 0
    for i in range(num_nodes):
        for j in range(num_nodes):
            if adj_mx[i, j] == 0:
                continue
            grad[i, e] = 1.
            grad[j, e] = -1.
            e += 1
    return grad


def init_network_weights(net, std=0.1):
    """
    Only for nn.Linear layers.
    """
    for m in net.modules():
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0, std=std)
            nn.init.constant_(m.bias, val=0)


def split_last_dim(data):
    last_dim = data.size()[-1]
    last_dim = last_dim // 2

    res = data[..., :last_dim], data[..., last_dim:]
    return res


def get_device(tensor):
    device = torch.device("cpu")
    if tensor.is_cuda:
        device = tensor.get_device()
    return device


def sample_standard_gaussian(mu, sigma):
    device = get_device(mu)

    d = torch.distributions.normal.Normal(torch.Tensor([0.]).to(device), torch.Tensor([1.]).to(device))
    r = d.sample(mu.size()).squeeze(-1)
    return r * sigma.float() + mu.float()


def create_net(n_inputs, n_outputs, n_layers=0, n_units=100, nonlinear=nn.Tanh):
    layers = [nn.Linear(n_inputs, n_units)]
    for i in range(n_layers):
        layers.append(nonlinear())
        layers.append(nn.Linear(n_units, n_units))

    layers.append(nonlinear())
    layers.append(nn.Linear(n_units, n_outputs))
    return nn.Sequential(*layers)
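To make `graph_grad` concrete, a quick check on a toy directed 3-cycle (run in the same module as the function above): each column is one edge, with +1 at its source node and -1 at its target.

```python
import numpy as np

adj = np.array([[0., 1., 0.],
                [0., 0., 1.],
                [1., 0., 0.]])
grad = graph_grad(adj)  # (num_nodes, num_edges) = (3, 3)
print(grad)
# edges enumerated row-major: 0->1, 1->2, 2->0
# tensor([[ 1.,  0., -1.],
#         [-1.,  1.,  0.],
#         [ 0., -1.,  1.]])
```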