Compare commits
2 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
abdd3165b8 | |
|
|
626bb4d2bb |
|
|
@ -168,6 +168,8 @@ STDEN/
|
||||||
models/gpt2/
|
models/gpt2/
|
||||||
pre-trained/
|
pre-trained/
|
||||||
|
|
||||||
|
# 注意:models/STDEN/ 是代码目录,不应该被忽略
|
||||||
|
|
||||||
# 数据集文件类型屏蔽
|
# 数据集文件类型屏蔽
|
||||||
*.csv
|
*.csv
|
||||||
*.npz
|
*.npz
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import torch.nn as nn
|
||||||
|
|
||||||
from models.STDEN import utils
|
from models.STDEN import utils
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
# 移除全局device设置,让模型自己决定设备
|
||||||
|
|
||||||
class LayerParams:
|
class LayerParams:
|
||||||
def __init__(self, rnn_network: nn.Module, layer_type: str):
|
def __init__(self, rnn_network: nn.Module, layer_type: str):
|
||||||
|
|
@ -15,7 +15,7 @@ class LayerParams:
|
||||||
|
|
||||||
def get_weights(self, shape):
|
def get_weights(self, shape):
|
||||||
if shape not in self._params_dict:
|
if shape not in self._params_dict:
|
||||||
nn_param = nn.Parameter(torch.empty(*shape, device=device))
|
nn_param = nn.Parameter(torch.empty(*shape))
|
||||||
nn.init.xavier_normal_(nn_param)
|
nn.init.xavier_normal_(nn_param)
|
||||||
self._params_dict[shape] = nn_param
|
self._params_dict[shape] = nn_param
|
||||||
self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
|
self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
|
||||||
|
|
@ -24,7 +24,7 @@ class LayerParams:
|
||||||
|
|
||||||
def get_biases(self, length, bias_start=0.0):
|
def get_biases(self, length, bias_start=0.0):
|
||||||
if length not in self._biases_dict:
|
if length not in self._biases_dict:
|
||||||
biases = nn.Parameter(torch.empty(length, device=device))
|
biases = nn.Parameter(torch.empty(length))
|
||||||
nn.init.constant_(biases, bias_start)
|
nn.init.constant_(biases, bias_start)
|
||||||
self._biases_dict[length] = biases
|
self._biases_dict[length] = biases
|
||||||
self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
|
self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
|
||||||
|
|
@ -77,7 +77,7 @@ class ODEFunc(nn.Module):
|
||||||
indices = np.column_stack((L.row, L.col))
|
indices = np.column_stack((L.row, L.col))
|
||||||
# this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L)
|
# this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L)
|
||||||
indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
|
indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
|
||||||
L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device)
|
L = torch.sparse_coo_tensor(indices.T, L.data.astype(np.float32), L.shape, dtype=torch.float32)
|
||||||
return L
|
return L
|
||||||
|
|
||||||
def forward(self, t_local, y, backwards = False):
|
def forward(self, t_local, y, backwards = False):
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,19 @@ class STDENModel(nn.Module, EncoderAttrs):
|
||||||
adj_mx = load_graph(config)
|
adj_mx = load_graph(config)
|
||||||
EncoderAttrs.__init__(self, config['model'], adj_mx)
|
EncoderAttrs.__init__(self, config['model'], adj_mx)
|
||||||
|
|
||||||
|
# 输入输出维度配置
|
||||||
|
self.input_dim = int(config['model'].get('input_dim', 1))
|
||||||
|
self.output_dim = int(config['model'].get('output_dim', 1))
|
||||||
|
|
||||||
|
# Node到Edge的转换层
|
||||||
|
self.node_to_edge = nn.Linear(self.num_nodes * self.input_dim, self.num_edges * self.input_dim)
|
||||||
|
# Edge到Node的转换层
|
||||||
|
self.edge_to_node = nn.Linear(self.num_edges * self.output_dim, self.num_nodes * self.output_dim)
|
||||||
|
|
||||||
|
# 初始化转换层权重
|
||||||
|
utils.init_network_weights(self.node_to_edge)
|
||||||
|
utils.init_network_weights(self.edge_to_node)
|
||||||
|
|
||||||
# 识别网络
|
# 识别网络
|
||||||
self.encoder_z0 = Encoder_z0_RNN(config['model'], adj_mx)
|
self.encoder_z0 = Encoder_z0_RNN(config['model'], adj_mx)
|
||||||
|
|
||||||
|
|
@ -63,15 +76,20 @@ class STDENModel(nn.Module, EncoderAttrs):
|
||||||
def forward(self, inputs, labels=None, batches_seen=None):
|
def forward(self, inputs, labels=None, batches_seen=None):
|
||||||
"""
|
"""
|
||||||
seq2seq前向传播
|
seq2seq前向传播
|
||||||
:param inputs: (seq_len, batch_size, num_edges * input_dim)
|
:param inputs: (batch_size, seq_len, num_nodes, input_dim) - 节点格式输入
|
||||||
:param labels: (horizon, batch_size, num_edges * output_dim)
|
:param labels: (batch_size, horizon, num_nodes, output_dim) - 节点格式标签
|
||||||
:param batches_seen: 已见批次数量
|
:param batches_seen: 已见批次数量
|
||||||
:return: outputs: (horizon, batch_size, num_edges * output_dim)
|
:return: outputs: (batch_size, horizon, num_nodes, output_dim) - 节点格式输出
|
||||||
"""
|
"""
|
||||||
# 编码初始潜在状态
|
# 输入格式转换:从node格式转换为edge格式
|
||||||
B, T, N, C = inputs.shape
|
B, T, N, C = inputs.shape
|
||||||
inputs = inputs.view(T, B, N * C)
|
inputs_node = inputs.view(T, B, N * C) # (T, B, N*C)
|
||||||
first_point_mu, first_point_std = self.encoder_z0(inputs)
|
|
||||||
|
# 将node格式转换为edge格式
|
||||||
|
inputs_edge = self.node_to_edge(inputs_node) # (T, B, E*C)
|
||||||
|
|
||||||
|
# 编码初始潜在状态
|
||||||
|
first_point_mu, first_point_std = self.encoder_z0(inputs_edge)
|
||||||
|
|
||||||
# 采样轨迹
|
# 采样轨迹
|
||||||
means_z0 = first_point_mu.repeat(self.n_traj_samples, 1, 1)
|
means_z0 = first_point_mu.repeat(self.n_traj_samples, 1, 1)
|
||||||
|
|
@ -87,10 +105,16 @@ class STDENModel(nn.Module, EncoderAttrs):
|
||||||
|
|
||||||
if self.save_latent:
|
if self.save_latent:
|
||||||
self.latent_feat = torch.mean(sol_ys.detach(), axis=1)
|
self.latent_feat = torch.mean(sol_ys.detach(), axis=1)
|
||||||
# 解码输出
|
|
||||||
outputs = self.decoder(sol_ys)
|
|
||||||
|
|
||||||
outputs = outputs.view(B, T, N, C)
|
# 解码输出(edge格式)
|
||||||
|
outputs_edge = self.decoder(sol_ys) # (horizon, B, E*output_dim)
|
||||||
|
|
||||||
|
# 将edge格式转换回node格式
|
||||||
|
outputs_node = self.edge_to_node(outputs_edge) # (horizon, B, N*output_dim)
|
||||||
|
|
||||||
|
# 重塑为最终输出格式
|
||||||
|
outputs = outputs_node.view(self.horizon, B, N, self.output_dim)
|
||||||
|
outputs = outputs.transpose(0, 1) # (B, horizon, N, output_dim)
|
||||||
|
|
||||||
return outputs, fe
|
return outputs, fe
|
||||||
|
|
||||||
|
|
@ -128,17 +152,17 @@ class Encoder_z0_RNN(nn.Module, EncoderAttrs):
|
||||||
"""
|
"""
|
||||||
seq_len, batch_size = inputs.size(0), inputs.size(1)
|
seq_len, batch_size = inputs.size(0), inputs.size(1)
|
||||||
|
|
||||||
# 重塑输入并处理
|
# 重塑输入并处理 - 现在输入是edge格式
|
||||||
inputs = inputs.reshape(seq_len, batch_size, self.num_nodes, self.input_dim)
|
inputs = inputs.reshape(seq_len, batch_size, self.num_edges, self.input_dim)
|
||||||
inputs = inputs.reshape(seq_len, batch_size * self.num_nodes, self.input_dim)
|
inputs = inputs.reshape(seq_len, batch_size * self.num_edges, self.input_dim)
|
||||||
|
|
||||||
# GRU处理
|
# GRU处理
|
||||||
outputs, _ = self.gru_rnn(inputs)
|
outputs, _ = self.gru_rnn(inputs)
|
||||||
last_output = outputs[-1]
|
last_output = outputs[-1]
|
||||||
|
|
||||||
# 重塑并转换维度
|
# 重塑并转换维度 - 从edge格式转换回node格式
|
||||||
last_output = torch.reshape(last_output, (batch_size, self.num_nodes, -1))
|
last_output = torch.reshape(last_output, (batch_size, self.num_edges, -1))
|
||||||
last_output = torch.transpose(last_output, (-2, -1))
|
last_output = torch.transpose(last_output, -2, -1)
|
||||||
last_output = torch.matmul(last_output, self.inv_grad).transpose(-2, -1)
|
last_output = torch.matmul(last_output, self.inv_grad).transpose(-2, -1)
|
||||||
|
|
||||||
# 生成均值和标准差
|
# 生成均值和标准差
|
||||||
|
|
@ -173,7 +197,7 @@ class Decoder(nn.Module):
|
||||||
outputs = torch.matmul(inputs, self.grap_grad)
|
outputs = torch.matmul(inputs, self.grap_grad)
|
||||||
|
|
||||||
# 重塑并平均采样轨迹
|
# 重塑并平均采样轨迹
|
||||||
outputs = outputs.reshape(horizon, n_traj_samples, batch_size, latent_dim, self.num_nodes, self.output_dim)
|
outputs = outputs.reshape(horizon, n_traj_samples, batch_size, latent_dim, self.num_edges, self.output_dim)
|
||||||
outputs = torch.mean(torch.mean(outputs, axis=3), axis=1)
|
outputs = torch.mean(torch.mean(outputs, axis=3), axis=1)
|
||||||
outputs = outputs.reshape(horizon, batch_size, -1)
|
outputs = outputs.reshape(horizon, batch_size, -1)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue