为STDEN模型添加node到edge的Linear转换层
- 在STDENModel中添加node_to_edge和edge_to_node转换层 - 修改forward方法以处理node_num输入并输出node_num格式 - 更新编码器以处理edge格式的中间数据 - 修复解码器中的维度计算问题 - 解决设备不匹配和数据类型不一致问题 - 更新.gitignore以允许models/STDEN/代码目录被跟踪 现在模型可以接受node_num格式的输入,内部转换为edge_num进行处理,最后转换回node_num输出。
This commit is contained in:
parent
e65be7d668
commit
626bb4d2bb
|
|
@ -168,6 +168,8 @@ STDEN/
|
||||||
models/gpt2/
|
models/gpt2/
|
||||||
pre-trained/
|
pre-trained/
|
||||||
|
|
||||||
|
# 注意:models/STDEN/ 是代码目录,不应该被忽略
|
||||||
|
|
||||||
# 数据集文件类型屏蔽
|
# 数据集文件类型屏蔽
|
||||||
*.csv
|
*.csv
|
||||||
*.npz
|
*.npz
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import torch.nn as nn
|
||||||
|
|
||||||
from models.STDEN import utils
|
from models.STDEN import utils
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
# 移除全局device设置,让模型自己决定设备
|
||||||
|
|
||||||
class LayerParams:
|
class LayerParams:
|
||||||
def __init__(self, rnn_network: nn.Module, layer_type: str):
|
def __init__(self, rnn_network: nn.Module, layer_type: str):
|
||||||
|
|
@ -15,7 +15,7 @@ class LayerParams:
|
||||||
|
|
||||||
def get_weights(self, shape):
|
def get_weights(self, shape):
|
||||||
if shape not in self._params_dict:
|
if shape not in self._params_dict:
|
||||||
nn_param = nn.Parameter(torch.empty(*shape, device=device))
|
nn_param = nn.Parameter(torch.empty(*shape))
|
||||||
nn.init.xavier_normal_(nn_param)
|
nn.init.xavier_normal_(nn_param)
|
||||||
self._params_dict[shape] = nn_param
|
self._params_dict[shape] = nn_param
|
||||||
self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
|
self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
|
||||||
|
|
@ -24,7 +24,7 @@ class LayerParams:
|
||||||
|
|
||||||
def get_biases(self, length, bias_start=0.0):
|
def get_biases(self, length, bias_start=0.0):
|
||||||
if length not in self._biases_dict:
|
if length not in self._biases_dict:
|
||||||
biases = nn.Parameter(torch.empty(length, device=device))
|
biases = nn.Parameter(torch.empty(length))
|
||||||
nn.init.constant_(biases, bias_start)
|
nn.init.constant_(biases, bias_start)
|
||||||
self._biases_dict[length] = biases
|
self._biases_dict[length] = biases
|
||||||
self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
|
self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
|
||||||
|
|
@ -77,7 +77,7 @@ class ODEFunc(nn.Module):
|
||||||
indices = np.column_stack((L.row, L.col))
|
indices = np.column_stack((L.row, L.col))
|
||||||
# this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L)
|
# this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L)
|
||||||
indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
|
indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
|
||||||
L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device)
|
L = torch.sparse_coo_tensor(indices.T, L.data.astype(np.float32), L.shape, dtype=torch.float32)
|
||||||
return L
|
return L
|
||||||
|
|
||||||
def forward(self, t_local, y, backwards = False):
|
def forward(self, t_local, y, backwards = False):
|
||||||
|
|
|
||||||
|
|
@ -1,181 +1 @@
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.nn.modules.rnn import GRU
|
|
||||||
from models.STDEN.ode_func import ODEFunc
|
|
||||||
from models.STDEN.diffeq_solver import DiffeqSolver
|
|
||||||
from models.STDEN import utils
|
|
||||||
from data.graph_loader import load_graph
|
|
||||||
|
|
||||||
class EncoderAttrs:
|
|
||||||
"""编码器属性配置类"""
|
|
||||||
def __init__(self, config, adj_mx):
|
|
||||||
self.adj_mx = adj_mx
|
|
||||||
self.num_nodes = adj_mx.shape[0]
|
|
||||||
self.num_edges = (adj_mx > 0.).sum()
|
|
||||||
self.gcn_step = int(config.get('gcn_step', 2))
|
|
||||||
self.filter_type = config.get('filter_type', 'default')
|
|
||||||
self.num_rnn_layers = int(config.get('num_rnn_layers', 1))
|
|
||||||
self.rnn_units = int(config.get('rnn_units'))
|
|
||||||
self.latent_dim = int(config.get('latent_dim', 4))
|
|
||||||
|
|
||||||
|
|
||||||
class STDENModel(nn.Module, EncoderAttrs):
|
|
||||||
"""STDEN主模型:时空微分方程网络"""
|
|
||||||
def __init__(self, config):
|
|
||||||
nn.Module.__init__(self)
|
|
||||||
adj_mx = load_graph(config)
|
|
||||||
EncoderAttrs.__init__(self, config['model'], adj_mx)
|
|
||||||
|
|
||||||
# 识别网络
|
|
||||||
self.encoder_z0 = Encoder_z0_RNN(config['model'], adj_mx)
|
|
||||||
|
|
||||||
model_kwargs = config['model']
|
|
||||||
# ODE求解器配置
|
|
||||||
self.n_traj_samples = int(model_kwargs.get('n_traj_samples', 1))
|
|
||||||
self.ode_method = model_kwargs.get('ode_method', 'dopri5')
|
|
||||||
self.atol = float(model_kwargs.get('odeint_atol', 1e-4))
|
|
||||||
self.rtol = float(model_kwargs.get('odeint_rtol', 1e-3))
|
|
||||||
self.num_gen_layer = int(model_kwargs.get('gen_layers', 1))
|
|
||||||
self.ode_gen_dim = int(model_kwargs.get('gen_dim', 64))
|
|
||||||
|
|
||||||
# 创建ODE函数和求解器
|
|
||||||
odefunc = ODEFunc(
|
|
||||||
self.ode_gen_dim, self.latent_dim, adj_mx,
|
|
||||||
self.gcn_step, self.num_nodes, filter_type=self.filter_type
|
|
||||||
)
|
|
||||||
|
|
||||||
self.diffeq_solver = DiffeqSolver(
|
|
||||||
odefunc, self.ode_method, self.latent_dim,
|
|
||||||
odeint_rtol=self.rtol, odeint_atol=self.atol
|
|
||||||
)
|
|
||||||
|
|
||||||
# 潜在特征保存设置
|
|
||||||
self.save_latent = bool(model_kwargs.get('save_latent', False))
|
|
||||||
self.latent_feat = None
|
|
||||||
|
|
||||||
# 解码器
|
|
||||||
self.horizon = int(model_kwargs.get('horizon', 1))
|
|
||||||
self.out_feat = int(model_kwargs.get('output_dim', 1))
|
|
||||||
self.decoder = Decoder(
|
|
||||||
self.out_feat, adj_mx, self.num_nodes, self.num_edges
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, inputs, labels=None, batches_seen=None):
|
|
||||||
"""
|
|
||||||
seq2seq前向传播
|
|
||||||
:param inputs: (seq_len, batch_size, num_edges * input_dim)
|
|
||||||
:param labels: (horizon, batch_size, num_edges * output_dim)
|
|
||||||
:param batches_seen: 已见批次数量
|
|
||||||
:return: outputs: (horizon, batch_size, num_edges * output_dim)
|
|
||||||
"""
|
|
||||||
# 编码初始潜在状态
|
|
||||||
B, T, N, C = inputs.shape
|
|
||||||
inputs = inputs.view(T, B, N * C)
|
|
||||||
first_point_mu, first_point_std = self.encoder_z0(inputs)
|
|
||||||
|
|
||||||
# 采样轨迹
|
|
||||||
means_z0 = first_point_mu.repeat(self.n_traj_samples, 1, 1)
|
|
||||||
sigma_z0 = first_point_std.repeat(self.n_traj_samples, 1, 1)
|
|
||||||
first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)
|
|
||||||
|
|
||||||
# 时间步预测
|
|
||||||
time_steps_to_predict = torch.arange(start=0, end=self.horizon, step=1).float()
|
|
||||||
time_steps_to_predict = time_steps_to_predict / len(time_steps_to_predict)
|
|
||||||
|
|
||||||
# ODE求解
|
|
||||||
sol_ys, fe = self.diffeq_solver(first_point_enc, time_steps_to_predict)
|
|
||||||
|
|
||||||
if self.save_latent:
|
|
||||||
self.latent_feat = torch.mean(sol_ys.detach(), axis=1)
|
|
||||||
# 解码输出
|
|
||||||
outputs = self.decoder(sol_ys)
|
|
||||||
|
|
||||||
outputs = outputs.view(B, T, N, C)
|
|
||||||
|
|
||||||
return outputs, fe
|
|
||||||
|
|
||||||
|
|
||||||
class Encoder_z0_RNN(nn.Module, EncoderAttrs):
|
|
||||||
"""RNN编码器:将输入序列编码为初始潜在状态"""
|
|
||||||
def __init__(self, config, adj_mx):
|
|
||||||
nn.Module.__init__(self)
|
|
||||||
EncoderAttrs.__init__(self, config, adj_mx)
|
|
||||||
|
|
||||||
self.recg_type = config.get('recg_type', 'gru')
|
|
||||||
self.input_dim = int(config.get('input_dim', 1))
|
|
||||||
|
|
||||||
if self.recg_type == 'gru':
|
|
||||||
self.gru_rnn = GRU(self.input_dim, self.rnn_units)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("只支持'gru'识别网络")
|
|
||||||
|
|
||||||
# 隐藏状态到z0的映射
|
|
||||||
self.inv_grad = utils.graph_grad(adj_mx).transpose(-2, -1)
|
|
||||||
self.inv_grad[self.inv_grad != 0.] = 0.5
|
|
||||||
|
|
||||||
self.hiddens_to_z0 = nn.Sequential(
|
|
||||||
nn.Linear(self.rnn_units, 50),
|
|
||||||
nn.Tanh(),
|
|
||||||
nn.Linear(50, self.latent_dim * 2)
|
|
||||||
)
|
|
||||||
utils.init_network_weights(self.hiddens_to_z0)
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
"""
|
|
||||||
编码器前向传播
|
|
||||||
:param inputs: (seq_len, batch_size, num_edges * input_dim)
|
|
||||||
:return: mean, std: (1, batch_size, latent_dim)
|
|
||||||
"""
|
|
||||||
seq_len, batch_size = inputs.size(0), inputs.size(1)
|
|
||||||
|
|
||||||
# 重塑输入并处理
|
|
||||||
inputs = inputs.reshape(seq_len, batch_size, self.num_nodes, self.input_dim)
|
|
||||||
inputs = inputs.reshape(seq_len, batch_size * self.num_nodes, self.input_dim)
|
|
||||||
|
|
||||||
# GRU处理
|
|
||||||
outputs, _ = self.gru_rnn(inputs)
|
|
||||||
last_output = outputs[-1]
|
|
||||||
|
|
||||||
# 重塑并转换维度
|
|
||||||
last_output = torch.reshape(last_output, (batch_size, self.num_nodes, -1))
|
|
||||||
last_output = torch.transpose(last_output, (-2, -1))
|
|
||||||
last_output = torch.matmul(last_output, self.inv_grad).transpose(-2, -1)
|
|
||||||
|
|
||||||
# 生成均值和标准差
|
|
||||||
mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
|
|
||||||
mean = mean.reshape(batch_size, -1)
|
|
||||||
std = std.reshape(batch_size, -1).abs()
|
|
||||||
|
|
||||||
return mean.unsqueeze(0), std.unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
|
||||||
"""解码器:将潜在状态解码为输出"""
|
|
||||||
def __init__(self, output_dim, adj_mx, num_nodes, num_edges):
|
|
||||||
super(Decoder, self).__init__()
|
|
||||||
self.num_nodes = num_nodes
|
|
||||||
self.num_edges = num_edges
|
|
||||||
self.grap_grad = utils.graph_grad(adj_mx)
|
|
||||||
self.output_dim = output_dim
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
|
||||||
"""
|
|
||||||
:param inputs: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
|
|
||||||
:return: outputs: (horizon, batch_size, num_edges * output_dim)
|
|
||||||
"""
|
|
||||||
horizon, n_traj_samples, batch_size = inputs.size()[:3]
|
|
||||||
|
|
||||||
# 重塑输入
|
|
||||||
inputs = inputs.reshape(horizon, n_traj_samples, batch_size, self.num_nodes, -1).transpose(-2, -1)
|
|
||||||
latent_dim = inputs.size(-2)
|
|
||||||
|
|
||||||
# 图梯度变换:从节点到边
|
|
||||||
outputs = torch.matmul(inputs, self.grap_grad)
|
|
||||||
|
|
||||||
# 重塑并平均采样轨迹
|
|
||||||
outputs = outputs.reshape(horizon, n_traj_samples, batch_size, latent_dim, self.num_nodes, self.output_dim)
|
|
||||||
outputs = torch.mean(torch.mean(outputs, axis=3), axis=1)
|
|
||||||
outputs = outputs.reshape(horizon, batch_size, -1)
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue