Project-I/models/STDEN/stden_model.py

182 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
from torch.nn.modules.rnn import GRU
from models.STDEN.ode_func import ODEFunc
from models.STDEN.diffeq_solver import DiffeqSolver
from models.STDEN import utils
from data.graph_loader import load_graph
class EncoderAttrs:
    """Attribute container for graph and encoder hyper-parameters.

    Intended to be initialised explicitly (``EncoderAttrs.__init__(self, ...)``)
    by classes that also derive from ``nn.Module``.
    """

    def __init__(self, config, adj_mx):
        """Copy graph statistics and hyper-parameters onto ``self``.

        :param config: mapping exposing ``get``; ``'rnn_units'`` has no default,
            so a missing key makes ``int(None)`` raise ``TypeError``.
        :param adj_mx: adjacency matrix exposing ``shape`` and supporting
            ``adj_mx > 0.`` followed by ``.sum()``.
        """
        lookup = config.get

        # Graph-derived attributes.
        self.adj_mx = adj_mx
        node_count = adj_mx.shape[0]
        self.num_nodes = node_count
        # Number of strictly positive entries in the adjacency matrix.
        positive_entries = adj_mx > 0.
        self.num_edges = positive_entries.sum()

        # Hyper-parameters, each coerced to its expected Python type.
        self.gcn_step = int(lookup('gcn_step', 2))
        self.filter_type = lookup('filter_type', 'default')
        self.num_rnn_layers = int(lookup('num_rnn_layers', 1))
        self.rnn_units = int(lookup('rnn_units'))
        self.latent_dim = int(lookup('latent_dim', 4))
class STDENModel(nn.Module, EncoderAttrs):
    """STDEN main model: a spatio-temporal differential equation network.

    The model encodes the input sequence into an initial latent state,
    integrates that state with an ODE solver over the prediction horizon and
    decodes the resulting trajectory into outputs.
    """

    def __init__(self, config):
        """Build the recognition network, the ODE solver and the decoder.

        :param config: mapping passed to ``load_graph`` to obtain the adjacency
            matrix; its ``'model'`` entry holds the hyper-parameters read here.
        """
        nn.Module.__init__(self)
        adj_mx = load_graph(config)
        EncoderAttrs.__init__(self, config['model'], adj_mx)

        # Recognition network.
        self.encoder_z0 = Encoder_z0_RNN(config['model'], adj_mx)

        model_kwargs = config['model']

        # ODE solver configuration.
        self.n_traj_samples = int(model_kwargs.get('n_traj_samples', 1))
        self.ode_method = model_kwargs.get('ode_method', 'dopri5')
        self.atol = float(model_kwargs.get('odeint_atol', 1e-4))
        self.rtol = float(model_kwargs.get('odeint_rtol', 1e-3))
        self.num_gen_layer = int(model_kwargs.get('gen_layers', 1))
        self.ode_gen_dim = int(model_kwargs.get('gen_dim', 64))

        # ODE function and solver.
        odefunc = ODEFunc(
            self.ode_gen_dim, self.latent_dim, adj_mx,
            self.gcn_step, self.num_nodes, filter_type=self.filter_type
        )
        self.diffeq_solver = DiffeqSolver(
            odefunc, self.ode_method, self.latent_dim,
            odeint_rtol=self.rtol, odeint_atol=self.atol
        )

        # Optional caching of the latent trajectory computed in ``forward``.
        self.save_latent = bool(model_kwargs.get('save_latent', False))
        self.latent_feat = None

        # Decoder.
        self.horizon = int(model_kwargs.get('horizon', 1))
        self.out_feat = int(model_kwargs.get('output_dim', 1))
        self.decoder = Decoder(
            self.out_feat, adj_mx, self.num_nodes, self.num_edges
        )

    def forward(self, inputs, labels=None, batches_seen=None):
        """Run the seq2seq forward pass.

        :param inputs: 4-D tensor unpacked as (batch_size, seq_len, N, C).
        :param labels: unused; kept for interface compatibility.
        :param batches_seen: unused; kept for interface compatibility.
        :return: ``(outputs, fe)`` where ``outputs`` has shape
            (batch_size, horizon, N, -1) and ``fe`` is the second value
            returned by the ODE solver.
        """
        B, T, N, C = inputs.shape
        # The encoder consumes a time-major tensor. ``view`` alone would only
        # reinterpret the batch-major memory and interleave samples with time
        # steps, so the axes are swapped explicitly before flattening.
        device = inputs.device
        inputs = inputs.permute(1, 0, 2, 3).reshape(T, B, N * C)

        # Encode the initial latent state.
        first_point_mu, first_point_std = self.encoder_z0(inputs)

        # Sample trajectories.
        means_z0 = first_point_mu.repeat(self.n_traj_samples, 1, 1)
        sigma_z0 = first_point_std.repeat(self.n_traj_samples, 1, 1)
        first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)

        # Normalised prediction time grid, created on the inputs' device so it
        # does not end up on CPU while the rest of the data is on an accelerator.
        time_steps_to_predict = torch.arange(
            start=0, end=self.horizon, step=1, device=device
        ).float()
        time_steps_to_predict = time_steps_to_predict / len(time_steps_to_predict)

        # Solve the ODE.
        sol_ys, fe = self.diffeq_solver(first_point_enc, time_steps_to_predict)
        if self.save_latent:
            self.latent_feat = torch.mean(sol_ys.detach(), dim=1)

        # Decode. The decoder returns a horizon-major tensor
        # (horizon, batch_size, features), so the batch axis is moved to the
        # front before the feature axis is split per node.
        outputs = self.decoder(sol_ys)
        horizon = outputs.size(0)
        outputs = outputs.permute(1, 0, 2).reshape(B, horizon, N, -1)
        return outputs, fe
class Encoder_z0_RNN(nn.Module, EncoderAttrs):
    """RNN encoder that maps an input sequence to the initial latent state."""

    def __init__(self, config, adj_mx):
        """Create the GRU and the projection from hidden states to z0.

        :param config: mapping of model hyper-parameters (read with ``get``).
        :param adj_mx: adjacency matrix forwarded to ``utils.graph_grad``.
        :raises NotImplementedError: if ``recg_type`` is not ``'gru'``.
        """
        nn.Module.__init__(self)
        EncoderAttrs.__init__(self, config, adj_mx)
        self.recg_type = config.get('recg_type', 'gru')
        self.input_dim = int(config.get('input_dim', 1))

        if self.recg_type == 'gru':
            self.gru_rnn = GRU(self.input_dim, self.rnn_units)
        else:
            raise NotImplementedError("只支持'gru'识别网络")

        # Transposed graph gradient with every non-zero entry set to 0.5.
        # NOTE(review): stored as a plain attribute rather than a registered
        # buffer, so it presumably does not follow ``module.to(device)`` —
        # confirm against ``utils.graph_grad``.
        self.inv_grad = utils.graph_grad(adj_mx).transpose(-2, -1)
        self.inv_grad[self.inv_grad != 0.] = 0.5

        # Projection from hidden states to the concatenated (mean, std) of z0.
        self.hiddens_to_z0 = nn.Sequential(
            nn.Linear(self.rnn_units, 50),
            nn.Tanh(),
            nn.Linear(50, self.latent_dim * 2)
        )
        utils.init_network_weights(self.hiddens_to_z0)

    def forward(self, inputs):
        """Encode a time-major sequence into the mean and std of z0.

        :param inputs: (seq_len, batch_size, num_nodes * input_dim)
        :return: ``(mean, std)``, each of shape (1, batch_size, -1), where the
            last axis is the flattened per-row latent output; ``std`` is
            non-negative.
        """
        seq_len, batch_size = inputs.size(0), inputs.size(1)

        # The first reshape checks the feature axis is num_nodes * input_dim;
        # the second folds the nodes into the batch axis so the GRU treats
        # each node as an independent sequence.
        inputs = inputs.reshape(seq_len, batch_size, self.num_nodes, self.input_dim)
        inputs = inputs.reshape(seq_len, batch_size * self.num_nodes, self.input_dim)

        # Run the GRU and keep only the last time step.
        outputs, _ = self.gru_rnn(inputs)
        last_output = outputs[-1]

        # (batch_size, num_nodes, rnn_units) -> (batch_size, rnn_units, num_nodes)
        last_output = torch.reshape(last_output, (batch_size, self.num_nodes, -1))
        # ``torch.transpose`` takes two integer dims; passing a tuple raises
        # ``TypeError``, so the two axes are given separately.
        last_output = last_output.transpose(-2, -1)
        # NOTE(review): this matmul requires ``inv_grad.shape[0] == num_nodes``;
        # verify against the shape returned by ``utils.graph_grad``.
        last_output = torch.matmul(last_output, self.inv_grad).transpose(-2, -1)

        # Produce the mean and the standard deviation.
        mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
        mean = mean.reshape(batch_size, -1)
        std = std.reshape(batch_size, -1).abs()
        return mean.unsqueeze(0), std.unsqueeze(0)
class Decoder(nn.Module):
    """Decoder: maps latent ODE states to outputs via the graph gradient."""

    def __init__(self, output_dim, adj_mx, num_nodes, num_edges):
        """Store the sizes and the graph gradient used for decoding.

        :param output_dim: size of the trailing output feature axis.
        :param adj_mx: adjacency matrix forwarded to ``utils.graph_grad``.
        :param num_nodes: number of graph nodes.
        :param num_edges: number of graph edges (stored, not used in ``forward``).
        """
        super().__init__()
        self.num_nodes = num_nodes
        self.num_edges = num_edges
        # Presumably shaped (num_nodes, num_edges) — confirm in utils.graph_grad.
        self.grap_grad = utils.graph_grad(adj_mx)
        self.output_dim = output_dim

    def forward(self, inputs):
        """Decode latent trajectories, averaging over samples and latent dims.

        :param inputs: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
        :return: (horizon, batch_size, M), where M is the size of the last axis
            of ``grap_grad``.
        """
        horizon, n_traj_samples, batch_size = inputs.size()[:3]

        # Split the flat feature axis per node, then put the node axis last:
        # (horizon, n_traj_samples, batch_size, latent_dim, num_nodes).
        inputs = inputs.reshape(
            horizon, n_traj_samples, batch_size, self.num_nodes, -1
        ).transpose(-2, -1)
        latent_dim = inputs.size(-2)

        # Graph-gradient transform: contracts the node axis with ``grap_grad``.
        outputs = torch.matmul(inputs, self.grap_grad)

        # The size of the transformed axis is inferred with -1 rather than
        # hard-coded to ``num_nodes``, which only fits a square gradient matrix
        # and raised for any other shape.
        outputs = outputs.reshape(
            horizon, n_traj_samples, batch_size, latent_dim, -1, self.output_dim
        )
        # Average over the latent axis first, then over the sampled trajectories.
        outputs = torch.mean(torch.mean(outputs, dim=3), dim=1)
        outputs = outputs.reshape(horizon, batch_size, -1)
        return outputs