diff --git a/.gitignore b/.gitignore index ac0c3ec..8d21011 100644 --- a/.gitignore +++ b/.gitignore @@ -160,4 +160,5 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -STDEN/ \ No newline at end of file +.STDEN/ +.data/PEMS08/ \ No newline at end of file diff --git a/STDEN b/STDEN new file mode 160000 index 0000000..e50a1ba --- /dev/null +++ b/STDEN @@ -0,0 +1 @@ +Subproject commit e50a1ba6d70528b3e684c85f316aed05bb5085f2 diff --git a/configs/STDEN/PEMS08.yaml b/configs/STDEN/PEMS08.yaml new file mode 100644 index 0000000..14fc617 --- /dev/null +++ b/configs/STDEN/PEMS08.yaml @@ -0,0 +1,65 @@ +basic: + device: cuda:0 + dataset: PEMS08 + model: STDEN + mode: train + seed: 2025 + +data: + dataset_dir: data/PEMS08 + val_batch_size: 32 + graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy + num_nodes: 170 + batch_size: 32 + input_dim: 1 + lag: 24 + horizon: 24 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 24 + days_per_week: 7 + +model: + l1_decay: 0 + seq_len: 12 + horizon: 12 + input_dim: 1 + output_dim: 1 + latent_dim: 4 + n_traj_samples: 3 + ode_method: dopri5 + odeint_atol: 0.00001 + odeint_rtol: 0.00001 + rnn_units: 64 + num_rnn_layers: 1 + gcn_step: 2 + filter_type: default # unkP IncP default + recg_type: gru + save_latent: false + nfe: false + +train: + loss: mae + batch_size: 64 + epochs: 100 + lr_init: 0.003 + mape_thresh: 0.001 + mae_thresh: None + debug: False + output_dim: 1 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + \ No newline at end of file diff --git a/configs/STDEN/stde_gt.yaml b/configs/STDEN/stde_gt.yaml new file mode 100644 index 0000000..1b2ca18 --- /dev/null +++ b/configs/STDEN/stde_gt.yaml @@ -0,0 +1,44 @@ +--- +log_base_dir: logs/BJ_GM +log_level: INFO + +data: + batch_size: 32 + dataset_dir: data/BJ_GM + val_batch_size: 32 + graph_pkl_filename: data/sensor_graph/adj_GM.npy + +model: + l1_decay: 0 + seq_len: 12 + horizon: 12 + input_dim: 1 + output_dim: 1 + latent_dim: 4 + n_traj_samples: 3 + ode_method: dopri5 + odeint_atol: 0.00001 + odeint_rtol: 0.00001 + rnn_units: 64 + num_rnn_layers: 1 + gcn_step: 2 + filter_type: default # unkP IncP default + recg_type: gru + save_latent: false + nfe: false + +train: + base_lr: 0.01 + dropout: 0 + load: 0 + epoch: 0 + epochs: 100 + epsilon: 1.0e-3 + lr_decay_ratio: 0.1 + max_grad_norm: 5 + min_learning_rate: 2.0e-06 + optimizer: adam + patience: 20 + steps: [20, 30, 40, 50] + test_every_n_epochs: 5 + \ No newline at end of file diff --git a/configs/STDEN/stde_wrs.yaml b/configs/STDEN/stde_wrs.yaml new file mode 100644 index 0000000..8402736 --- /dev/null +++ b/configs/STDEN/stde_wrs.yaml @@ -0,0 +1,44 @@ +--- +log_base_dir: logs/BJ_RM +log_level: INFO + +data: + batch_size: 32 + dataset_dir: data/BJ_RM + val_batch_size: 32 + graph_pkl_filename: data/sensor_graph/adj_RM.npy + +model: + l1_decay: 0 + seq_len: 12 + horizon: 12 + input_dim: 1 + output_dim: 1 + latent_dim: 4 + n_traj_samples: 3 + ode_method: dopri5 + odeint_atol: 0.00001 + odeint_rtol: 0.00001 + rnn_units: 64 # for recognition + num_rnn_layers: 1 + gcn_step: 2 + filter_type: default # unkP IncP default + recg_type: gru + save_latent: false + nfe: false + +train: + base_lr: 0.01 + dropout: 0 + load: 0 # 0 for not load + 
+  epoch: 0
+  epochs: 100
+  epsilon: 1.0e-3
+  lr_decay_ratio: 0.1
+  max_grad_norm: 5
+  min_learning_rate: 2.0e-06
+  optimizer: adam
+  patience: 20
+  steps: [20, 30, 40, 50]
+  test_every_n_epochs: 5
+ 
\ No newline at end of file
diff --git a/configs/STDEN/stde_zgc.yaml b/configs/STDEN/stde_zgc.yaml
new file mode 100644
index 0000000..f291538
--- /dev/null
+++ b/configs/STDEN/stde_zgc.yaml
@@ -0,0 +1,44 @@
+---
+log_base_dir: logs/BJ_XZ
+log_level: INFO
+
+data:
+  batch_size: 32
+  dataset_dir: data/BJ_XZ
+  val_batch_size: 32
+  graph_pkl_filename: data/sensor_graph/adj_XZ.npy
+
+model:
+  l1_decay: 0
+  seq_len: 12
+  horizon: 12
+  input_dim: 1
+  output_dim: 1
+  latent_dim: 4
+  n_traj_samples: 3
+  ode_method: dopri5
+  odeint_atol: 0.00001
+  odeint_rtol: 0.00001
+  rnn_units: 64
+  num_rnn_layers: 1
+  gcn_step: 2
+  filter_type: default # unkP IncP default
+  recg_type: gru
+  save_latent: false
+  nfe: false
+
+train:
+  base_lr: 0.01
+  dropout: 0
+  load: 0 # 0 for not load
+  epoch: 0
+  epochs: 100
+  epsilon: 1.0e-3
+  lr_decay_ratio: 0.1
+  max_grad_norm: 5
+  min_learning_rate: 2.0e-06
+  optimizer: adam
+  patience: 20
+  steps: [20, 30, 40, 50]
+  test_every_n_epochs: 5
+ 
\ No newline at end of file
diff --git a/models/STDEN/STDEN_modules.md b/models/STDEN/STDEN_modules.md
new file mode 100644
index 0000000..09cdc63
--- /dev/null
+++ b/models/STDEN/STDEN_modules.md
@@ -0,0 +1,88 @@
+### STDEN modules and execution flow (indented hierarchy table)
+
+Module | Class/Function | Input (shape) | Output (shape)
+--- | --- | --- | ---
+1 | STDENModel.forward | inputs: (seq_len, batch_size, num_edges x input_dim) | outputs: (horizon, batch_size, num_edges x output_dim); fe: (nfe:int, time:float)
+1.1 | Encoder_z0_RNN.forward | (seq_len, batch_size, num_edges x input_dim) | mean: (1, batch_size, num_nodes x latent_dim); std: (1, batch_size, num_nodes x latent_dim)
+1.1.1 | utils.sample_standard_gaussian | mu: (n_traj, batch, num_nodes x latent_dim); sigma: same shape | z0: (n_traj, batch, num_nodes x latent_dim)
+1.2 | DiffeqSolver.forward | first_point: (n_traj, batch, num_nodes x latent_dim); t: (horizon,) | sol_ys: (horizon, n_traj, batch, num_nodes x latent_dim); fe: (nfe:int, time:float)
+1.2.1 | ODEFunc.forward | t_local: scalar/1-D; y: (B, num_nodes x latent_dim) | dy/dt: (B, num_nodes x latent_dim)
+1.3 | Decoder.forward | (horizon, n_traj, batch, num_nodes x latent_dim) | (horizon, batch, num_edges x output_dim)
+
+---
+
+### Detail module — Encoder_z0_RNN
+
+Step | Operation | Input (shape) | Output (shape)
+--- | --- | --- | ---
+1 | Reshape into edge batches | (seq_len, batch, num_edges x input_dim) | (seq_len, batch, num_edges, input_dim)
+2 | Fold edges into the batch | (seq_len, batch, num_edges, input_dim) | (seq_len, batch x num_edges, input_dim)
+3 | GRU sequence encoding | same as above | (seq_len, batch x num_edges, rnn_units)
+4 | Take the last time step | same as above | (batch x num_edges, rnn_units)
+5 | Restore the edge dimension | (batch x num_edges, rnn_units) | (batch, num_edges, rnn_units)
+6 | Transpose + edge-to-node mapping | (batch, num_edges, rnn_units) via inv_grad | (batch, num_nodes, rnn_units)
+7 | Fully connected map to 2 x latent | (batch, num_nodes, rnn_units) | (batch, num_nodes, 2 x latent_dim)
+8 | Split mean / std | same as above | mean/std: (batch, num_nodes, latent_dim)
+9 | Flatten and add a time axis | (batch, num_nodes, latent_dim) | (1, batch, num_nodes x latent_dim)
+
+Note: inv_grad is derived from `utils.graph_grad(adj).T` with rescaling; `hiddens_to_z0` is a two-layer MLP with Tanh followed by a linear map to 2 x latent_dim.
+
+---
+
+### Detail module — Sampling (utils.sample_standard_gaussian)
+
+Step | Operation | Input (shape) | Output (shape)
+--- | --- | --- | ---
+1 | Repeat to n_traj | mean/std: (1, batch, N·Z), repeated | (n_traj, batch, N·Z)
+2 | Reparameterized sampling | mu, sigma | z0: (n_traj, batch, N·Z)
+
+Here N·Z = num_nodes x latent_dim.
+
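+A minimal, self-contained sketch of this reparameterized sampling step in plain
+PyTorch (the sizes below are placeholders chosen only for illustration):
+
+```python
+import torch
+
+n_traj, batch, N, Z = 3, 2, 5, 4      # n_traj_samples, batch, num_nodes, latent_dim
+mean = torch.zeros(1, batch, N * Z)   # encoder output mean, (1, batch, N*Z)
+std = torch.ones(1, batch, N * Z)     # encoder output std,  (1, batch, N*Z)
+
+mu = mean.repeat(n_traj, 1, 1)        # step 1: repeat to (n_traj, batch, N*Z)
+sigma = std.repeat(n_traj, 1, 1)
+
+eps = torch.randn_like(mu)            # step 2: standard-normal noise
+z0 = mu + sigma * eps                 # reparameterized sample, (n_traj, batch, N*Z)
+print(z0.shape)                       # torch.Size([3, 2, 20])
+```
+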
+---
+
+### Detail module — DiffeqSolver (including the ODEFunc call)
+
+Step | Operation | Input (shape) | Output (shape)
+--- | --- | --- | ---
+1 | Merge the sample dimension | first_point: (n_traj, batch, N·Z) | (n_traj x batch, N·Z)
+2 | ODE integration | t: (horizon,), y0 | pred_y: (horizon, n_traj x batch, N·Z)
+3 | Restore the dimensions | same as above | (horizon, n_traj, batch, N·Z)
+4 | Record the cost | odefunc.nfe, elapsed_time | fe: (nfe:int, time:float)
+
+With the default filter (filter_type="default"), ODEFunc models a diffusion process: random-walk supports plus a multi-hop graph-convolution gate.
+
+---
+
+### Detail module — ODEFunc (default diffusion process)
+
+Step | Operation | Input (shape) | Output (shape)
+--- | --- | --- | ---
+1 | Shape normalization | y: (B, N·Z) → (B, N, Z) | (B, N, Z)
+2 | Multi-hop graph convolution _gconv | (B, N, Z) | (B, N, Z') with Z' set as needed (usually kept at Z)
+3 | Gate θ | _gconv(..., output=latent_dim) → Sigmoid | θ: (B, N·Z)
+4 | Generative field ode_func_net | stacked _gconv + activation | f(y): (B, N·Z)
+5 | Right-hand-side gradient | - θ ⊙ f(y) | dy/dt: (B, N·Z)
+
+Notes:
+- The support matrices come from `utils.calculate_random_walk_matrix(adj)` (forward and backward) and drive a sparse Chebyshev-style recursion that builds the multi-hop channels.
+- With `filter_type="unkP"`, a fully connected network built by `create_net` computes the gradient point-wise in the node domain instead.
+
+---
+
+### Detail module — Decoder
+
+Step | Operation | Input (shape) | Output (shape)
+--- | --- | --- | ---
+1 | Reshape to the node domain | (T, S, B, N·Z) → (T, S, B, N, Z) | (T, S, B, N, Z)
+2 | Node-to-edge mapping | multiply by graph_grad (N, E) | (T, S, B, Z, E)
+3 | Average trajectories and channels | mean over the S and Z dimensions | (T, B, E)
+4 | Flatten to the output dimension | accounts for output_dim (usually 1) | (T, B, E x output_dim)
+
+Notation: T = horizon, S = n_traj_samples, N = num_nodes, E = num_edges, Z = latent_dim, B = batch.
+
+---
+
+### Remarks and conventions
+- Internally the model uses edge-flattened sequential input: `(seq_len, batch, num_edges x input_dim)`.
+- Graph operators: `utils.graph_grad(adj)` has shape `(N, E)`; `utils.calculate_random_walk_matrix(adj)` produces the sparse random-walk matrices used by the graph convolution.
+- Key hyperparameters (passed in via the config): `latent_dim`, `rnn_units`, `gcn_step`, `n_traj_samples`, `ode_method`, `horizon`, `input_dim`, `output_dim`.
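+
+As a worked illustration of the graph operators above, the sketch below builds
+the `(N, E)` gradient operator for a 3-node toy graph the same way
+`utils.graph_grad` does, and applies it to a made-up node signal (all values
+here are invented for the example):
+
+```python
+import torch
+
+adj = torch.tensor([[0., 1., 1.],
+                    [0., 0., 1.],
+                    [0., 0., 0.]])        # 3 nodes, 3 directed edges
+
+num_nodes = adj.shape[0]
+num_edges = int((adj > 0.).sum())
+grad = torch.zeros(num_nodes, num_edges)  # (N, E), one +1/-1 pair per edge
+e = 0
+for i in range(num_nodes):
+    for j in range(num_nodes):
+        if adj[i, j] == 0:
+            continue
+        grad[i, e], grad[j, e] = 1., -1.  # edge e runs from node i to node j
+        e += 1
+
+x = torch.tensor([3., 1., 0.])            # one scalar value per node
+print(x @ grad)                           # per-edge differences: tensor([2., 3., 1.])
+```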
diff --git a/models/STDEN/diffeq_solver.py b/models/STDEN/diffeq_solver.py
new file mode 100644
index 0000000..dfd0b17
--- /dev/null
+++ b/models/STDEN/diffeq_solver.py
@@ -0,0 +1,49 @@
+import torch
+import torch.nn as nn
+import time
+
+from torchdiffeq import odeint
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+class DiffeqSolver(nn.Module):
+    def __init__(self, odefunc, method, latent_dim,
+                 odeint_rtol=1e-4, odeint_atol=1e-5):
+        nn.Module.__init__(self)
+
+        self.ode_method = method
+        self.odefunc = odefunc
+        self.latent_dim = latent_dim
+
+        self.rtol = odeint_rtol
+        self.atol = odeint_atol
+
+    def forward(self, first_point, time_steps_to_pred):
+        """
+        Decode the trajectory through the ODE solver.
+
+        :param time_steps_to_pred: horizon
+        :param first_point: (n_traj_samples, batch_size, num_nodes * latent_dim)
+        :return: pred_y: shape (horizon, n_traj_samples, batch_size, num_nodes * latent_dim),
+            plus fe = (number of function evaluations, elapsed seconds)
+        """
+        n_traj_samples, batch_size = first_point.size()[0], first_point.size()[1]
+        first_point = first_point.reshape(n_traj_samples * batch_size, -1)  # reduce complexity by merging dimensions
+
+        # pred_y shape: (horizon, n_traj_samples * batch_size, num_nodes * latent_dim)
+        start_time = time.time()
+        self.odefunc.nfe = 0
+        pred_y = odeint(self.odefunc,
+                        first_point,
+                        time_steps_to_pred,
+                        rtol=self.rtol,
+                        atol=self.atol,
+                        method=self.ode_method)
+        time_fe = time.time() - start_time
+
+        # pred_y shape: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
+        pred_y = pred_y.reshape(pred_y.size()[0], n_traj_samples, batch_size, -1)
+        # assert(pred_y.size()[1] == n_traj_samples)
+        # assert(pred_y.size()[2] == batch_size)
+
+        return pred_y, (self.odefunc.nfe, time_fe)
+ 
\ No newline at end of file
diff --git a/models/STDEN/ode_func.py b/models/STDEN/ode_func.py
new file mode 100644
index 0000000..10456e3
--- /dev/null
+++ b/models/STDEN/ode_func.py
@@ -0,0 +1,165 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from models.STDEN import utils
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+class LayerParams:
+    def __init__(self, rnn_network: nn.Module, layer_type: str):
+        self._rnn_network = rnn_network
+        self._params_dict = {}
+        self._biases_dict = {}
+        self._type = layer_type
+
+    def get_weights(self, shape):
+        if shape not in self._params_dict:
+            nn_param = nn.Parameter(torch.empty(*shape, device=device))
+            nn.init.xavier_normal_(nn_param)
+            self._params_dict[shape] = nn_param
+            self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
+                                                 nn_param)
+        return self._params_dict[shape]
+
+    def get_biases(self, length, bias_start=0.0):
+        if length not in self._biases_dict:
+            biases = nn.Parameter(torch.empty(length, device=device))
+            nn.init.constant_(biases, bias_start)
+            self._biases_dict[length] = biases
+            self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
+                                                 biases)
+
+        return self._biases_dict[length]
+
+class ODEFunc(nn.Module):
+    def __init__(self, num_units, latent_dim, adj_mx, gcn_step, num_nodes,
+                 gen_layers=1, nonlinearity='tanh', filter_type="default"):
+        """
+        :param num_units: dimensionality of the hidden layers
+        :param latent_dim: dimensionality used for the ODE (input and output); analog of a continuous latent state
+        :param adj_mx: adjacency matrix of the graph
+        :param gcn_step: number of graph-convolution (diffusion) steps
+        :param num_nodes: number of graph nodes
+        :param gen_layers: hidden layers in each ODE func.
+        :param nonlinearity: 'tanh' or 'relu'
+        :param filter_type: one of "default", "unkP", "IncP"
+ """ + super(ODEFunc, self).__init__() + self._activation = torch.tanh if nonlinearity == 'tanh' else torch.relu + + self._num_nodes = num_nodes + self._num_units = num_units # hidden dimension + self._latent_dim = latent_dim + self._gen_layers = gen_layers + self.nfe = 0 + + self._filter_type = filter_type + if(self._filter_type == "unkP"): + ode_func_net = utils.create_net(latent_dim, latent_dim, n_units=num_units) + utils.init_network_weights(ode_func_net) + self.gradient_net = ode_func_net + else: + self._gcn_step = gcn_step + self._gconv_params = LayerParams(self, 'gconv') + self._supports = [] + supports = [] + supports.append(utils.calculate_random_walk_matrix(adj_mx).T) + supports.append(utils.calculate_random_walk_matrix(adj_mx.T).T) + + for support in supports: + self._supports.append(self._build_sparse_matrix(support)) + + @staticmethod + def _build_sparse_matrix(L): + L = L.tocoo() + indices = np.column_stack((L.row, L.col)) + # this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L) + indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))] + L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device) + return L + + def forward(self, t_local, y, backwards = False): + """ + Perform one step in solving ODE. Given current data point y and current time point t_local, returns gradient dy/dt at this time point + + t_local: current time point + y: value at the current time point, shape (B, num_nodes * latent_dim) + + :return + - Output: A `2-D` tensor with shape `(B, num_nodes * latent_dim)`. + """ + self.nfe += 1 + grad = self.get_ode_gradient_nn(t_local, y) + if backwards: + grad = -grad + return grad + + def get_ode_gradient_nn(self, t_local, inputs): + if(self._filter_type == "unkP"): + grad = self._fc(inputs) + elif (self._filter_type == "IncP"): + grad = - self.ode_func_net(inputs) + else: # default is diffusion process + # theta shape: (B, num_nodes * latent_dim) + theta = torch.sigmoid(self._gconv(inputs, self._latent_dim, bias_start=1.0)) + grad = - theta * self.ode_func_net(inputs) + return grad + + def ode_func_net(self, inputs): + c = inputs + for i in range(self._gen_layers): + c = self._gconv(c, self._num_units) + c = self._activation(c) + c = self._gconv(c, self._latent_dim) + c = self._activation(c) + return c + + def _fc(self, inputs): + batch_size = inputs.size()[0] + grad = self.gradient_net(inputs.view(batch_size * self._num_nodes, self._latent_dim)) + return grad.reshape(batch_size, self._num_nodes * self._latent_dim) # (batch_size, num_nodes, latent_dim) + + @staticmethod + def _concat(x, x_): + x_ = x_.unsqueeze(0) + return torch.cat([x, x_], dim=0) + + def _gconv(self, inputs, output_size, bias_start=0.0): + # Reshape input and state to (batch_size, num_nodes, input_dim/state_dim) + batch_size = inputs.shape[0] + inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1)) + # state = torch.reshape(state, (batch_size, self._num_nodes, -1)) + # inputs_and_state = torch.cat([inputs, state], dim=2) + input_size = inputs.size(2) + + x = inputs + x0 = x.permute(1, 2, 0) # (num_nodes, total_arg_size, batch_size) + x0 = torch.reshape(x0, shape=[self._num_nodes, input_size * batch_size]) + x = torch.unsqueeze(x0, 0) + + if self._gcn_step == 0: + pass + else: + for support in self._supports: + x1 = torch.sparse.mm(support, x0) + x = self._concat(x, x1) + + for k in range(2, self._gcn_step + 1): + x2 = 2 * torch.sparse.mm(support, x1) - x0 + x = self._concat(x, x2) + x1, x0 = x2, x1 + + num_matrices = len(self._supports) * 
self._gcn_step + 1  # +1 accounts for x itself
+        x = torch.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size])
+        x = x.permute(3, 1, 2, 0)  # (batch_size, num_nodes, input_size, order)
+        x = torch.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices])
+
+        weights = self._gconv_params.get_weights((input_size * num_matrices, output_size))
+        x = torch.matmul(x, weights)  # (batch_size * self._num_nodes, output_size)
+
+        biases = self._gconv_params.get_biases(output_size, bias_start)
+        x += biases
+        # Reshape back to 2D: (batch_size, num_nodes, state_dim) -> (batch_size, num_nodes * state_dim)
+        return torch.reshape(x, [batch_size, self._num_nodes * output_size])
diff --git a/models/STDEN/stden_model.py b/models/STDEN/stden_model.py
new file mode 100644
index 0000000..0b6ae5f
--- /dev/null
+++ b/models/STDEN/stden_model.py
@@ -0,0 +1,181 @@
+import torch
+import torch.nn as nn
+from torch.nn.modules.rnn import GRU
+from models.STDEN.ode_func import ODEFunc
+from models.STDEN.diffeq_solver import DiffeqSolver
+from models.STDEN import utils
+from data.graph_loader import load_graph
+
+class EncoderAttrs:
+    """Encoder attribute configuration class."""
+    def __init__(self, config, adj_mx):
+        self.adj_mx = adj_mx
+        self.num_nodes = adj_mx.shape[0]
+        self.num_edges = (adj_mx > 0.).sum()
+        self.gcn_step = int(config.get('gcn_step', 2))
+        self.filter_type = config.get('filter_type', 'default')
+        self.num_rnn_layers = int(config.get('num_rnn_layers', 1))
+        self.rnn_units = int(config.get('rnn_units'))
+        self.latent_dim = int(config.get('latent_dim', 4))
+
+
+class STDENModel(nn.Module, EncoderAttrs):
+    """STDEN main model: spatio-temporal differential equation network."""
+    def __init__(self, config):
+        nn.Module.__init__(self)
+        adj_mx = load_graph(config)
+        EncoderAttrs.__init__(self, config['model'], adj_mx)
+
+        # Recognition network
+        self.encoder_z0 = Encoder_z0_RNN(config['model'], adj_mx)
+
+        model_kwargs = config['model']
+        # ODE solver configuration
+        self.n_traj_samples = int(model_kwargs.get('n_traj_samples', 1))
+        self.ode_method = model_kwargs.get('ode_method', 'dopri5')
+        self.atol = float(model_kwargs.get('odeint_atol', 1e-4))
+        self.rtol = float(model_kwargs.get('odeint_rtol', 1e-3))
+        self.num_gen_layer = int(model_kwargs.get('gen_layers', 1))
+        self.ode_gen_dim = int(model_kwargs.get('gen_dim', 64))
+
+        # Build the ODE function and solver
+        odefunc = ODEFunc(
+            self.ode_gen_dim, self.latent_dim, adj_mx,
+            self.gcn_step, self.num_nodes, filter_type=self.filter_type
+        )
+
+        self.diffeq_solver = DiffeqSolver(
+            odefunc, self.ode_method, self.latent_dim,
+            odeint_rtol=self.rtol, odeint_atol=self.atol
+        )
+
+        # Settings for saving latent features
+        self.save_latent = bool(model_kwargs.get('save_latent', False))
+        self.latent_feat = None
+
+        # Decoder
+        self.horizon = int(model_kwargs.get('horizon', 1))
+        self.out_feat = int(model_kwargs.get('output_dim', 1))
+        self.decoder = Decoder(
+            self.out_feat, adj_mx, self.num_nodes, self.num_edges
+        )
+
+    def forward(self, inputs, labels=None, batches_seen=None):
+        """
+        seq2seq forward pass.
+        :param inputs: (batch_size, seq_len, num_nodes, input_dim)
+        :param labels: (horizon, batch_size, num_edges * output_dim); unused here
+        :param batches_seen: number of batches seen so far
+        :return: outputs: (batch_size, horizon, num_nodes, output_dim); fe: solver cost (nfe, seconds)
+        """
+        # Encode the initial latent state
+        B, T, N, C = inputs.shape
+        # (B, T, N, C) -> (T, B, N * C); permute first so batch and time are not mixed
+        inputs = inputs.permute(1, 0, 2, 3).reshape(T, B, N * C)
+        first_point_mu, first_point_std = self.encoder_z0(inputs)
+
+        # Sample trajectories
+        means_z0 = first_point_mu.repeat(self.n_traj_samples, 1, 1)
+        sigma_z0 = first_point_std.repeat(self.n_traj_samples, 1, 1)
+        first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)
+
+        # Prediction time points, normalized to [0, 1)
+        time_steps_to_predict = torch.arange(start=0, end=self.horizon, step=1).float()
+        time_steps_to_predict = time_steps_to_predict / len(time_steps_to_predict)
+
+        # Solve the ODE
+        sol_ys, fe = self.diffeq_solver(first_point_enc, time_steps_to_predict)
+
+        if self.save_latent:
+            self.latent_feat = torch.mean(sol_ys.detach(), axis=1)
+        # Decode the outputs
+        outputs = self.decoder(sol_ys)
+
+        # (horizon, B, ...) -> (B, T, N, C), keeping the batch and time axes aligned
+        outputs = outputs.reshape(T, B, N, C).permute(1, 0, 2, 3)
+
+        return outputs, fe
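+
+# ---------------------------------------------------------------------------
+# Illustrative usage sketch (comments only; nothing below is executed by the
+# repo). It assumes the training loop feeds batch-first tensors of shape
+# (B, T, N, C), matching the `B, T, N, C = inputs.shape` unpacking above:
+#
+#   model = STDENModel(config)                 # config mirrors configs/STDEN/PEMS08.yaml
+#   y_hat, (nfe, seconds) = model(x_batch)     # y_hat: (B, T, N, C); fe = ODE solver cost
+#   loss = torch.abs(y_hat - y_batch).mean()   # MAE, as set by train.loss in the config
+# ---------------------------------------------------------------------------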
+
+
+class Encoder_z0_RNN(nn.Module, EncoderAttrs):
+    """RNN encoder: encodes the input sequence into the initial latent state."""
+    def __init__(self, config, adj_mx):
+        nn.Module.__init__(self)
+        EncoderAttrs.__init__(self, config, adj_mx)
+
+        self.recg_type = config.get('recg_type', 'gru')
+        self.input_dim = int(config.get('input_dim', 1))
+
+        if self.recg_type == 'gru':
+            self.gru_rnn = GRU(self.input_dim, self.rnn_units)
+        else:
+            raise NotImplementedError("Only the 'gru' recognition network is supported")
+
+        # Mapping from hidden states to z0
+        self.inv_grad = utils.graph_grad(adj_mx).transpose(-2, -1)
+        self.inv_grad[self.inv_grad != 0.] = 0.5
+
+        self.hiddens_to_z0 = nn.Sequential(
+            nn.Linear(self.rnn_units, 50),
+            nn.Tanh(),
+            nn.Linear(50, self.latent_dim * 2)
+        )
+        utils.init_network_weights(self.hiddens_to_z0)
+
+    def forward(self, inputs):
+        """
+        Encoder forward pass.
+        :param inputs: (seq_len, batch_size, num_nodes * input_dim)
+        :return: mean, std: (1, batch_size, num_nodes * latent_dim)
+        """
+        seq_len, batch_size = inputs.size(0), inputs.size(1)
+
+        # Reshape the input
+        inputs = inputs.reshape(seq_len, batch_size, self.num_nodes, self.input_dim)
+        inputs = inputs.reshape(seq_len, batch_size * self.num_nodes, self.input_dim)
+
+        # GRU encoding
+        outputs, _ = self.gru_rnn(inputs)
+        last_output = outputs[-1]
+
+        # Reshape and remap dimensions
+        last_output = torch.reshape(last_output, (batch_size, self.num_nodes, -1))
+        last_output = last_output.transpose(-2, -1)
+        last_output = torch.matmul(last_output, self.inv_grad).transpose(-2, -1)
+
+        # Produce mean and standard deviation
+        mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
+        mean = mean.reshape(batch_size, -1)
+        std = std.reshape(batch_size, -1).abs()
+
+        return mean.unsqueeze(0), std.unsqueeze(0)
+
+
+class Decoder(nn.Module):
+    """Decoder: maps latent states to outputs."""
+    def __init__(self, output_dim, adj_mx, num_nodes, num_edges):
+        super(Decoder, self).__init__()
+        self.num_nodes = num_nodes
+        self.num_edges = num_edges
+        self.grap_grad = utils.graph_grad(adj_mx)
+        self.output_dim = output_dim
+
+    def forward(self, inputs):
+        """
+        :param inputs: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
+        :return: outputs: (horizon, batch_size, num_edges * output_dim)
+        """
+        horizon, n_traj_samples, batch_size = inputs.size()[:3]
+
+        # Reshape the input
+        inputs = inputs.reshape(horizon, n_traj_samples, batch_size, self.num_nodes, -1).transpose(-2, -1)
+        latent_dim = inputs.size(-2)
+
+        # Graph-gradient transform: from nodes to edges
+        outputs = torch.matmul(inputs, self.grap_grad)
+
+        # Reshape and average over sampled trajectories
+        # (horizon, n_traj_samples, batch_size, latent_dim, num_edges, output_dim)
+        outputs = outputs.reshape(horizon, n_traj_samples, batch_size, latent_dim, self.num_edges, self.output_dim)
+        outputs = torch.mean(torch.mean(outputs, axis=3), axis=1)
+        outputs = outputs.reshape(horizon, batch_size, -1)
+
+        return outputs
+
diff --git a/models/STDEN/utils.py b/models/STDEN/utils.py
new file mode 100644
index 0000000..a17e9cc
--- /dev/null
+++ b/models/STDEN/utils.py
@@ -0,0 +1,234 @@
+import logging
+import numpy as np
+import os
+import time
+import scipy.sparse as sp
+import sys
+import torch
+import torch.nn as nn
+
+
+class DataLoader(object):
+    def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
+        """
+
+        :param xs:
+        :param ys:
+        :param batch_size:
+        :param
pad_with_last_sample: pad with the last sample to make number of samples divisible to batch_size. + """ + self.batch_size = batch_size + self.current_ind = 0 + if pad_with_last_sample: + num_padding = (batch_size - (len(xs) % batch_size)) % batch_size + x_padding = np.repeat(xs[-1:], num_padding, axis=0) + y_padding = np.repeat(ys[-1:], num_padding, axis=0) + xs = np.concatenate([xs, x_padding], axis=0) + ys = np.concatenate([ys, y_padding], axis=0) + self.size = len(xs) + self.num_batch = int(self.size // self.batch_size) + if shuffle: + permutation = np.random.permutation(self.size) + xs, ys = xs[permutation], ys[permutation] + self.xs = xs + self.ys = ys + + def get_iterator(self): + self.current_ind = 0 + + def _wrapper(): + while self.current_ind < self.num_batch: + start_ind = self.batch_size * self.current_ind + end_ind = min(self.size, self.batch_size * (self.current_ind + 1)) + x_i = self.xs[start_ind: end_ind, ...] + y_i = self.ys[start_ind: end_ind, ...] + yield (x_i, y_i) + self.current_ind += 1 + + return _wrapper() + + +class StandardScaler: + """ + Standard the input + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def transform(self, data): + return (data - self.mean) / self.std + + def inverse_transform(self, data): + return (data * self.std) + self.mean + + +def calculate_random_walk_matrix(adj_mx): + adj_mx = sp.coo_matrix(adj_mx) + d = np.array(adj_mx.sum(1)) + d_inv = np.power(d, -1).flatten() + d_inv[np.isinf(d_inv)] = 0. + d_mat_inv = sp.diags(d_inv) + random_walk_mx = d_mat_inv.dot(adj_mx).tocoo() + return random_walk_mx + + +def config_logging(log_dir, log_filename='info.log', level=logging.INFO): + # Add file handler and stdout handler + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + # Create the log directory if necessary. + try: + os.makedirs(log_dir) + except OSError: + pass + file_handler = logging.FileHandler(os.path.join(log_dir, log_filename)) + file_handler.setFormatter(formatter) + file_handler.setLevel(level=level) + # Add console handler. + console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(console_formatter) + console_handler.setLevel(level=level) + logging.basicConfig(handlers=[file_handler, console_handler], level=level) + + +def get_logger(log_dir, name, log_filename='info.log', level=logging.INFO): + logger = logging.getLogger(name) + logger.setLevel(level) + # Add file handler and stdout handler + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + file_handler = logging.FileHandler(os.path.join(log_dir, log_filename)) + file_handler.setFormatter(formatter) + # Add console handler. 
+ console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(console_formatter) + logger.addHandler(file_handler) + logger.addHandler(console_handler) + # Add google cloud log handler + logger.info('Log directory: %s', log_dir) + return logger + + +def get_log_dir(kwargs): + log_dir = kwargs['train'].get('log_dir') + if log_dir is None: + batch_size = kwargs['data'].get('batch_size') + + filter_type = kwargs['model'].get('filter_type') + gcn_step = kwargs['model'].get('gcn_step') + horizon = kwargs['model'].get('horizon') + latent_dim = kwargs['model'].get('latent_dim') + n_traj_samples = kwargs['model'].get('n_traj_samples') + ode_method = kwargs['model'].get('ode_method') + + seq_len = kwargs['model'].get('seq_len') + rnn_units = kwargs['model'].get('rnn_units') + recg_type = kwargs['model'].get('recg_type') + + if filter_type == 'unkP': + filter_type_abbr = 'UP' + elif filter_type == 'IncP': + filter_type_abbr = 'NV' + else: + filter_type_abbr = 'DF' + + run_id = 'STDEN_%s-%d_%s-%d_L-%d_N-%d_M-%s_bs-%d_%d-%d_%s/' % ( + recg_type, rnn_units, filter_type_abbr, gcn_step, latent_dim, n_traj_samples, ode_method, batch_size, + seq_len, horizon, time.strftime('%m%d%H%M%S')) + base_dir = kwargs.get('log_base_dir') + log_dir = os.path.join(base_dir, run_id) + if not os.path.exists(log_dir): + os.makedirs(log_dir) + return log_dir + + +def load_dataset(dataset_dir, batch_size, val_batch_size=None, **kwargs): + if ('BJ' in dataset_dir): + data = dict(np.load(os.path.join(dataset_dir, 'flow.npz'))) # convert readonly NpzFile to writable dict Object + for category in ['train', 'val', 'test']: + data['x_' + category] = data['x_' + category] # [..., :4] # ignore the time index + else: + data = {} + for category in ['train', 'val', 'test']: + cat_data = np.load(os.path.join(dataset_dir, category + '.npz')) + data['x_' + category] = cat_data['x'] + data['y_' + category] = cat_data['y'] + scaler = StandardScaler(mean=data['x_train'].mean(), std=data['x_train'].std()) + # Data format + for category in ['train', 'val', 'test']: + data['x_' + category] = scaler.transform(data['x_' + category]) + data['y_' + category] = scaler.transform(data['y_' + category]) + data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True) + data['val_loader'] = DataLoader(data['x_val'], data['y_val'], val_batch_size, shuffle=False) + data['test_loader'] = DataLoader(data['x_test'], data['y_test'], val_batch_size, shuffle=False) + data['scaler'] = scaler + + return data + + +def load_graph_data(pkl_filename): + adj_mx = np.load(pkl_filename) + return adj_mx + + +def graph_grad(adj_mx): + """Fetch the graph gradient operator.""" + num_nodes = adj_mx.shape[0] + + num_edges = (adj_mx > 0.).sum() + grad = torch.zeros(num_nodes, num_edges) + e = 0 + for i in range(num_nodes): + for j in range(num_nodes): + if adj_mx[i, j] == 0: + continue + grad[i, e] = 1. + grad[j, e] = -1. + e += 1 + return grad + + +def init_network_weights(net, std=0.1): + """ + Just for nn.Linear net. 
+ """ + for m in net.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, mean=0, std=std) + nn.init.constant_(m.bias, val=0) + + +def split_last_dim(data): + last_dim = data.size()[-1] + last_dim = last_dim // 2 + + res = data[..., :last_dim], data[..., last_dim:] + return res + + +def get_device(tensor): + device = torch.device("cpu") + if tensor.is_cuda: + device = tensor.get_device() + return device + + +def sample_standard_gaussian(mu, sigma): + device = get_device(mu) + + d = torch.distributions.normal.Normal(torch.Tensor([0.]).to(device), torch.Tensor([1.]).to(device)) + r = d.sample(mu.size()).squeeze(-1) + return r * sigma.float() + mu.float() + + +def create_net(n_inputs, n_outputs, n_layers=0, n_units=100, nonlinear=nn.Tanh): + layers = [nn.Linear(n_inputs, n_units)] + for i in range(n_layers): + layers.append(nonlinear()) + layers.append(nn.Linear(n_units, n_units)) + + layers.append(nonlinear()) + layers.append(nn.Linear(n_units, n_outputs)) + return nn.Sequential(*layers)