diff --git a/.gitignore b/.gitignore index 9de7b6a..cf2451b 100644 --- a/.gitignore +++ b/.gitignore @@ -168,6 +168,8 @@ STDEN/ models/gpt2/ pre-trained/ +# 注意:models/STDEN/ 是代码目录,不应该被忽略 + # 数据集文件类型屏蔽 *.csv *.npz diff --git a/models/STDEN/ode_func.py b/models/STDEN/ode_func.py index 10456e3..1822560 100644 --- a/models/STDEN/ode_func.py +++ b/models/STDEN/ode_func.py @@ -4,7 +4,7 @@ import torch.nn as nn from models.STDEN import utils -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# 移除全局device设置,让模型自己决定设备 class LayerParams: def __init__(self, rnn_network: nn.Module, layer_type: str): @@ -15,7 +15,7 @@ class LayerParams: def get_weights(self, shape): if shape not in self._params_dict: - nn_param = nn.Parameter(torch.empty(*shape, device=device)) + nn_param = nn.Parameter(torch.empty(*shape)) nn.init.xavier_normal_(nn_param) self._params_dict[shape] = nn_param self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)), @@ -24,7 +24,7 @@ class LayerParams: def get_biases(self, length, bias_start=0.0): if length not in self._biases_dict: - biases = nn.Parameter(torch.empty(length, device=device)) + biases = nn.Parameter(torch.empty(length)) nn.init.constant_(biases, bias_start) self._biases_dict[length] = biases self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)), @@ -77,7 +77,7 @@ class ODEFunc(nn.Module): indices = np.column_stack((L.row, L.col)) # this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L) indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))] - L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device) + L = torch.sparse_coo_tensor(indices.T, L.data.astype(np.float32), L.shape, dtype=torch.float32) return L def forward(self, t_local, y, backwards = False): diff --git a/models/STDEN/stden_model.py b/models/STDEN/stden_model.py index 0b6ae5f..8b13789 100644 --- a/models/STDEN/stden_model.py +++ b/models/STDEN/stden_model.py @@ -1,181 +1 @@ -import torch -import torch.nn as nn -from torch.nn.modules.rnn import GRU -from models.STDEN.ode_func import ODEFunc -from models.STDEN.diffeq_solver import DiffeqSolver -from models.STDEN import utils -from data.graph_loader import load_graph - -class EncoderAttrs: - """编码器属性配置类""" - def __init__(self, config, adj_mx): - self.adj_mx = adj_mx - self.num_nodes = adj_mx.shape[0] - self.num_edges = (adj_mx > 0.).sum() - self.gcn_step = int(config.get('gcn_step', 2)) - self.filter_type = config.get('filter_type', 'default') - self.num_rnn_layers = int(config.get('num_rnn_layers', 1)) - self.rnn_units = int(config.get('rnn_units')) - self.latent_dim = int(config.get('latent_dim', 4)) - - -class STDENModel(nn.Module, EncoderAttrs): - """STDEN主模型:时空微分方程网络""" - def __init__(self, config): - nn.Module.__init__(self) - adj_mx = load_graph(config) - EncoderAttrs.__init__(self, config['model'], adj_mx) - - # 识别网络 - self.encoder_z0 = Encoder_z0_RNN(config['model'], adj_mx) - - model_kwargs = config['model'] - # ODE求解器配置 - self.n_traj_samples = int(model_kwargs.get('n_traj_samples', 1)) - self.ode_method = model_kwargs.get('ode_method', 'dopri5') - self.atol = float(model_kwargs.get('odeint_atol', 1e-4)) - self.rtol = float(model_kwargs.get('odeint_rtol', 1e-3)) - self.num_gen_layer = int(model_kwargs.get('gen_layers', 1)) - self.ode_gen_dim = int(model_kwargs.get('gen_dim', 64)) - - # 创建ODE函数和求解器 - odefunc = ODEFunc( - self.ode_gen_dim, self.latent_dim, adj_mx, - self.gcn_step, self.num_nodes, filter_type=self.filter_type - ) - - self.diffeq_solver = DiffeqSolver( - odefunc, self.ode_method, self.latent_dim, - odeint_rtol=self.rtol, odeint_atol=self.atol - ) - - # 潜在特征保存设置 - self.save_latent = bool(model_kwargs.get('save_latent', False)) - self.latent_feat = None - - # 解码器 - self.horizon = int(model_kwargs.get('horizon', 1)) - self.out_feat = int(model_kwargs.get('output_dim', 1)) - self.decoder = Decoder( - self.out_feat, adj_mx, self.num_nodes, self.num_edges - ) - - def forward(self, inputs, labels=None, batches_seen=None): - """ - seq2seq前向传播 - :param inputs: (seq_len, batch_size, num_edges * input_dim) - :param labels: (horizon, batch_size, num_edges * output_dim) - :param batches_seen: 已见批次数量 - :return: outputs: (horizon, batch_size, num_edges * output_dim) - """ - # 编码初始潜在状态 - B, T, N, C = inputs.shape - inputs = inputs.view(T, B, N * C) - first_point_mu, first_point_std = self.encoder_z0(inputs) - - # 采样轨迹 - means_z0 = first_point_mu.repeat(self.n_traj_samples, 1, 1) - sigma_z0 = first_point_std.repeat(self.n_traj_samples, 1, 1) - first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0) - - # 时间步预测 - time_steps_to_predict = torch.arange(start=0, end=self.horizon, step=1).float() - time_steps_to_predict = time_steps_to_predict / len(time_steps_to_predict) - - # ODE求解 - sol_ys, fe = self.diffeq_solver(first_point_enc, time_steps_to_predict) - - if self.save_latent: - self.latent_feat = torch.mean(sol_ys.detach(), axis=1) - # 解码输出 - outputs = self.decoder(sol_ys) - - outputs = outputs.view(B, T, N, C) - - return outputs, fe - - -class Encoder_z0_RNN(nn.Module, EncoderAttrs): - """RNN编码器:将输入序列编码为初始潜在状态""" - def __init__(self, config, adj_mx): - nn.Module.__init__(self) - EncoderAttrs.__init__(self, config, adj_mx) - - self.recg_type = config.get('recg_type', 'gru') - self.input_dim = int(config.get('input_dim', 1)) - - if self.recg_type == 'gru': - self.gru_rnn = GRU(self.input_dim, self.rnn_units) - else: - raise NotImplementedError("只支持'gru'识别网络") - - # 隐藏状态到z0的映射 - self.inv_grad = utils.graph_grad(adj_mx).transpose(-2, -1) - self.inv_grad[self.inv_grad != 0.] = 0.5 - - self.hiddens_to_z0 = nn.Sequential( - nn.Linear(self.rnn_units, 50), - nn.Tanh(), - nn.Linear(50, self.latent_dim * 2) - ) - utils.init_network_weights(self.hiddens_to_z0) - - def forward(self, inputs): - """ - 编码器前向传播 - :param inputs: (seq_len, batch_size, num_edges * input_dim) - :return: mean, std: (1, batch_size, latent_dim) - """ - seq_len, batch_size = inputs.size(0), inputs.size(1) - - # 重塑输入并处理 - inputs = inputs.reshape(seq_len, batch_size, self.num_nodes, self.input_dim) - inputs = inputs.reshape(seq_len, batch_size * self.num_nodes, self.input_dim) - - # GRU处理 - outputs, _ = self.gru_rnn(inputs) - last_output = outputs[-1] - - # 重塑并转换维度 - last_output = torch.reshape(last_output, (batch_size, self.num_nodes, -1)) - last_output = torch.transpose(last_output, (-2, -1)) - last_output = torch.matmul(last_output, self.inv_grad).transpose(-2, -1) - - # 生成均值和标准差 - mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output)) - mean = mean.reshape(batch_size, -1) - std = std.reshape(batch_size, -1).abs() - - return mean.unsqueeze(0), std.unsqueeze(0) - - -class Decoder(nn.Module): - """解码器:将潜在状态解码为输出""" - def __init__(self, output_dim, adj_mx, num_nodes, num_edges): - super(Decoder, self).__init__() - self.num_nodes = num_nodes - self.num_edges = num_edges - self.grap_grad = utils.graph_grad(adj_mx) - self.output_dim = output_dim - - def forward(self, inputs): - """ - :param inputs: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim) - :return: outputs: (horizon, batch_size, num_edges * output_dim) - """ - horizon, n_traj_samples, batch_size = inputs.size()[:3] - - # 重塑输入 - inputs = inputs.reshape(horizon, n_traj_samples, batch_size, self.num_nodes, -1).transpose(-2, -1) - latent_dim = inputs.size(-2) - - # 图梯度变换:从节点到边 - outputs = torch.matmul(inputs, self.grap_grad) - - # 重塑并平均采样轨迹 - outputs = outputs.reshape(horizon, n_traj_samples, batch_size, latent_dim, self.num_nodes, self.output_dim) - outputs = torch.mean(torch.mean(outputs, axis=3), axis=1) - outputs = outputs.reshape(horizon, batch_size, -1) - - return outputs