Project-I/models/STDEN/stden_model.py

206 lines
8.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
from torch.nn.modules.rnn import GRU
from models.STDEN.ode_func import ODEFunc
from models.STDEN.diffeq_solver import DiffeqSolver
from models.STDEN import utils
from data.graph_loader import load_graph
class EncoderAttrs:
    """Mixin that stores graph statistics and encoder hyper-parameters.

    :param config: ``.get``-style mapping with the model settings. Recognised
        keys: ``gcn_step`` (default 2), ``filter_type`` (default
        ``'default'``), ``num_rnn_layers`` (default 1), ``rnn_units``
        (required), ``latent_dim`` (default 4).
    :param adj_mx: adjacency matrix; ``shape[0]`` is taken as the number of
        nodes and every entry ``> 0`` is counted as one edge.
    :raises TypeError: if ``rnn_units`` is absent from ``config``.
    """

    def __init__(self, config, adj_mx):
        self.adj_mx = adj_mx
        self.num_nodes = int(adj_mx.shape[0])
        # ``.sum()`` yields a NumPy scalar or a 0-dim tensor depending on the
        # type of ``adj_mx``; normalise to a built-in int because the value is
        # later used as a layer size and as a reshape dimension.
        self.num_edges = int((adj_mx > 0.).sum())

        self.gcn_step = int(config.get('gcn_step', 2))
        self.filter_type = config.get('filter_type', 'default')
        self.num_rnn_layers = int(config.get('num_rnn_layers', 1))

        # ``rnn_units`` has no default: fail with an explicit message instead
        # of the opaque error produced by ``int(None)``.
        rnn_units = config.get('rnn_units')
        if rnn_units is None:
            raise TypeError("'rnn_units' must be provided in the model config")
        self.rnn_units = int(rnn_units)

        self.latent_dim = int(config.get('latent_dim', 4))
class STDENModel(nn.Module, EncoderAttrs):
    """STDEN main model: a spatio-temporal differential equation network.

    Pipeline: node-format input -> linear node-to-edge projection -> GRU
    recognition network producing the initial latent state -> ODE solver ->
    decoder -> linear edge-to-node projection -> node-format output.

    :param config: nested configuration. ``config['model']`` supplies the
        hyper-parameters read here and by ``EncoderAttrs``; the full
        ``config`` is handed to ``load_graph`` to obtain the adjacency matrix.
    """

    def __init__(self, config):
        nn.Module.__init__(self)
        adj_mx = load_graph(config)
        EncoderAttrs.__init__(self, config['model'], adj_mx)

        # Per-node feature sizes of the input and of the prediction.
        self.input_dim = int(config['model'].get('input_dim', 1))
        self.output_dim = int(config['model'].get('output_dim', 1))

        # Learned projections between node space and edge space. Sizes are
        # cast to int so a NumPy/tensor ``num_edges`` cannot leak into the
        # layer construction.
        self.node_to_edge = nn.Linear(
            int(self.num_nodes * self.input_dim),
            int(self.num_edges * self.input_dim),
        )
        self.edge_to_node = nn.Linear(
            int(self.num_edges * self.output_dim),
            int(self.num_nodes * self.output_dim),
        )
        utils.init_network_weights(self.node_to_edge)
        utils.init_network_weights(self.edge_to_node)

        # Recognition network producing the initial latent state.
        self.encoder_z0 = Encoder_z0_RNN(config['model'], adj_mx)

        model_kwargs = config['model']

        # ODE solver settings.
        self.n_traj_samples = int(model_kwargs.get('n_traj_samples', 1))
        self.ode_method = model_kwargs.get('ode_method', 'dopri5')
        self.atol = float(model_kwargs.get('odeint_atol', 1e-4))
        self.rtol = float(model_kwargs.get('odeint_rtol', 1e-3))
        self.num_gen_layer = int(model_kwargs.get('gen_layers', 1))
        self.ode_gen_dim = int(model_kwargs.get('gen_dim', 64))

        # ODE dynamics function and the solver wrapping it.
        odefunc = ODEFunc(
            self.ode_gen_dim, self.latent_dim, adj_mx,
            self.gcn_step, self.num_nodes, filter_type=self.filter_type
        )
        self.diffeq_solver = DiffeqSolver(
            odefunc, self.ode_method, self.latent_dim,
            odeint_rtol=self.rtol, odeint_atol=self.atol
        )

        # Optionally keep the latent trajectory of the last forward pass.
        self.save_latent = bool(model_kwargs.get('save_latent', False))
        self.latent_feat = None

        # Decoder.
        self.horizon = int(model_kwargs.get('horizon', 1))
        self.out_feat = int(model_kwargs.get('output_dim', 1))
        self.decoder = Decoder(
            self.out_feat, adj_mx, self.num_nodes, self.num_edges
        )

    def forward(self, inputs, labels=None, batches_seen=None):
        """Run the seq2seq forward pass.

        :param inputs: ``(batch_size, seq_len, num_nodes, input_dim)`` tensor
            in node format.
        :param labels: unused here; kept for interface compatibility.
        :param batches_seen: unused here; kept for interface compatibility.
        :return: tuple ``(outputs, fe)`` where ``outputs`` has shape
            ``(batch_size, horizon, num_nodes, output_dim)`` and ``fe`` is the
            second value returned by the ODE solver, passed through unchanged.
        """
        B, T, N, C = inputs.shape

        # Put time first. A bare ``view(T, B, N * C)`` would only reinterpret
        # the memory of the batch-first tensor and mix samples with time
        # steps, so the axes are swapped explicitly before flattening.
        inputs_node = inputs.permute(1, 0, 2, 3).reshape(T, B, N * C)

        # Node format -> edge format: (T, B, E*C).
        inputs_edge = self.node_to_edge(inputs_node)

        # Encode the observed sequence into the initial latent distribution.
        first_point_mu, first_point_std = self.encoder_z0(inputs_edge)

        # Replicate along the leading axis, once per trajectory sample, then
        # sample the starting point of each trajectory.
        means_z0 = first_point_mu.repeat(self.n_traj_samples, 1, 1)
        sigma_z0 = first_point_std.repeat(self.n_traj_samples, 1, 1)
        first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)

        # Prediction times 0, 1/h, ..., (h-1)/h, created on the input's device
        # so the solver never has to combine CPU and accelerator tensors.
        time_steps_to_predict = torch.arange(
            start=0, end=self.horizon, step=1, device=inputs.device
        ).float()
        time_steps_to_predict = time_steps_to_predict / len(time_steps_to_predict)

        # Integrate the latent dynamics.
        sol_ys, fe = self.diffeq_solver(first_point_enc, time_steps_to_predict)

        if self.save_latent:
            # Average over axis 1 of the detached solution (presumably the
            # trajectory-sample axis — confirm against DiffeqSolver).
            self.latent_feat = torch.mean(sol_ys.detach(), dim=1)

        # Decode to edge format: (horizon, B, E*output_dim).
        outputs_edge = self.decoder(sol_ys)

        # Edge format -> node format: (horizon, B, N*output_dim).
        outputs_node = self.edge_to_node(outputs_edge)

        # Unflatten the node axis and return batch-first.
        outputs = outputs_node.reshape(self.horizon, B, N, self.output_dim)
        outputs = outputs.transpose(0, 1)  # (B, horizon, N, output_dim)
        return outputs, fe
class Encoder_z0_RNN(nn.Module, EncoderAttrs):
    """GRU recognition network mapping an input sequence to the initial latent state.

    :param config: ``.get``-style model settings; besides the keys consumed by
        ``EncoderAttrs`` it reads ``recg_type`` (default ``'gru'``) and
        ``input_dim`` (default 1).
    :param adj_mx: adjacency matrix forwarded to ``EncoderAttrs`` and to
        ``utils.graph_grad``.
    :raises NotImplementedError: if ``recg_type`` is anything but ``'gru'``.
    """

    def __init__(self, config, adj_mx):
        nn.Module.__init__(self)
        EncoderAttrs.__init__(self, config, adj_mx)
        self.recg_type = config.get('recg_type', 'gru')
        self.input_dim = int(config.get('input_dim', 1))

        if self.recg_type == 'gru':
            self.gru_rnn = GRU(self.input_dim, self.rnn_units)
        else:
            raise NotImplementedError("只支持'gru'识别网络")

        # Transposed graph-gradient operator with every non-zero entry set to
        # 0.5; used in ``forward`` to bring edge features back to nodes.
        inv_grad = utils.graph_grad(adj_mx).transpose(-2, -1)
        inv_grad[inv_grad != 0.] = 0.5
        # Non-persistent buffer: follows ``.to()`` / ``.cuda()`` with the
        # module while leaving ``state_dict`` (and old checkpoints) unchanged.
        self.register_buffer('inv_grad', inv_grad, persistent=False)

        # Maps the last hidden state to the concatenated (mean, std) of z0.
        self.hiddens_to_z0 = nn.Sequential(
            nn.Linear(self.rnn_units, 50),
            nn.Tanh(),
            nn.Linear(50, self.latent_dim * 2)
        )
        utils.init_network_weights(self.hiddens_to_z0)

    def forward(self, inputs):
        """Encode an edge-format sequence into the mean and std of z0.

        :param inputs: ``(seq_len, batch_size, num_edges * input_dim)``.
        :return: ``(mean, std)``, each flattened per sample and given a
            leading axis of size 1, i.e. ``(1, batch_size, -1)``. The last
            axis is presumably ``num_nodes * latent_dim`` — confirm against
            ``utils.graph_grad`` and ``utils.split_last_dim``. ``std`` is
            made non-negative with ``abs``.
        """
        seq_len, batch_size = inputs.size(0), inputs.size(1)

        # Treat every edge as an independent sequence for the GRU by folding
        # the edge axis into the batch axis.
        inputs = inputs.reshape(seq_len, batch_size, self.num_edges, self.input_dim)
        inputs = inputs.reshape(seq_len, batch_size * self.num_edges, self.input_dim)

        # Only the hidden state at the final time step is used.
        outputs, _ = self.gru_rnn(inputs)
        last_output = outputs[-1]

        # Unfold the edge axis, then project edges back to nodes by
        # multiplying the (batch, hidden, edges) view with ``inv_grad``.
        last_output = torch.reshape(last_output, (batch_size, self.num_edges, -1))
        last_output = torch.transpose(last_output, -2, -1)
        last_output = torch.matmul(last_output, self.inv_grad).transpose(-2, -1)

        # Split the projected features into mean and std.
        mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
        mean = mean.reshape(batch_size, -1)
        std = std.reshape(batch_size, -1).abs()
        return mean.unsqueeze(0), std.unsqueeze(0)
class Decoder(nn.Module):
    """Decoder turning latent node states into edge-format outputs.

    :param output_dim: number of output features per edge.
    :param adj_mx: adjacency matrix passed to ``utils.graph_grad``.
    :param num_nodes: number of graph nodes.
    :param num_edges: number of graph edges.
    """

    def __init__(self, output_dim, adj_mx, num_nodes, num_edges):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_edges = num_edges
        # Non-persistent buffer: the operator follows ``.to()`` / ``.cuda()``
        # with the module while ``state_dict`` stays unchanged.
        self.register_buffer('grap_grad', utils.graph_grad(adj_mx), persistent=False)
        self.output_dim = output_dim

    def forward(self, inputs):
        """Decode latent trajectories.

        :param inputs: ``(horizon, n_traj_samples, batch_size, num_nodes * latent_dim)``.
        :return: ``(horizon, batch_size, num_edges * output_dim)``, averaged
            over the latent axis and over the trajectory samples.
        """
        horizon, n_traj_samples, batch_size = inputs.size()[:3]

        # Unflatten the node axis and move it last: (..., latent_dim, num_nodes).
        inputs = inputs.reshape(
            horizon, n_traj_samples, batch_size, self.num_nodes, -1
        ).transpose(-2, -1)
        latent_dim = inputs.size(-2)

        # Graph-gradient transform from node space to edge space.
        outputs = torch.matmul(inputs, self.grap_grad)

        # NOTE(review): if ``grap_grad`` has ``num_edges`` columns, this
        # reshape only has a matching element count for ``output_dim == 1`` —
        # confirm whether larger output dims are meant to be supported.
        outputs = outputs.reshape(
            horizon, n_traj_samples, batch_size,
            latent_dim, self.num_edges, self.output_dim
        )

        # Average over the latent axis first, then over trajectory samples.
        outputs = torch.mean(torch.mean(outputs, dim=3), dim=1)
        outputs = outputs.reshape(horizon, batch_size, -1)
        return outputs