89 lines
4.4 KiB
Python
89 lines
4.4 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from model.D2STGNN.diffusion_block.dif_block import DifBlock
|
||
from model.D2STGNN.inherent_block.inh_block import InhBlock
|
||
from model.D2STGNN.dynamic_graph_conv.dy_graph_conv import DynamicGraphConstructor
|
||
from model.D2STGNN.decouple.estimation_gate import EstimationGate
|
||
|
||
class DecoupleLayer(nn.Module):
    """One decoupling stage: estimation gate -> diffusion block -> inherent block.

    Args:
        hidden_dim: hidden width, passed as the first positional argument to
            both the diffusion and the inherent block.
        fk_dim: passed to both blocks as ``forecast_hidden_dim``.
        args: configuration mapping. ``args['node_hidden']`` and
            ``args['time_emb_dim']`` size the estimation gate; the whole
            mapping is also forwarded as keyword arguments to both blocks.
    """

    def __init__(self, hidden_dim, fk_dim, args):
        super().__init__()
        gate_kwargs = dict(
            node_emb_dim=args['node_hidden'],
            time_emb_dim=args['time_emb_dim'],
            hidden_dim=64,  # fixed gate width
        )
        self.est_gate = EstimationGate(**gate_kwargs)
        # Only the explicitly needed arguments are named here; everything else
        # in `args` (including the 'dy_graph' flag) reaches the blocks via **args.
        self.dif_layer = DifBlock(hidden_dim, forecast_hidden_dim=fk_dim, **args)
        self.inh_layer = InhBlock(hidden_dim, forecast_hidden_dim=fk_dim, **args)

    def forward(self, x, dyn_graph, sta_graph=None, node_u=None, node_d=None, t_in_day=None, t_in_week=None):
        """Apply the gate, then the diffusion block, then the inherent block.

        Args:
            x: input hidden states for this layer.
            dyn_graph: dynamic graph(s) handed to the diffusion block.
            sta_graph: static graph(s) handed to the diffusion block.
            node_u, node_d: node embeddings consumed by the estimation gate.
            t_in_day, t_in_week: time embeddings consumed by the estimation gate.

        Returns:
            Tuple ``(inherent_backcast, diffusion_hidden, inherent_hidden)``:
            the first output of the inherent block, the second output of the
            diffusion block, and the second output of the inherent block.
        """
        gated_input = self.est_gate(node_u, node_d, t_in_day, t_in_week, x)
        diffusion_backcast, diffusion_hidden = self.dif_layer(x, gated_input, dyn_graph, sta_graph)
        inherent_backcast, inherent_hidden = self.inh_layer(diffusion_backcast)
        return inherent_backcast, diffusion_hidden, inherent_hidden
|
||
|
||
class D2STGNN(nn.Module):
    """Stack of decoupling layers with a two-layer MLP regression head.

    All hyper-parameters come from the ``args`` mapping.

    Keys read at construction: ``num_nodes``, ``num_layers``, ``num_hidden``,
    ``forecast_dim``, ``output_hidden``, ``output_dim``, ``input_dim``,
    ``time_emb_dim``, ``node_hidden`` and ``gap``.

    Key read in ``forward``: ``horizon``.

    Optional keys:
    - ``num_timesteps_in_day`` (default 288): number of time-of-day embedding rows.
    - ``dy_graph`` (default False): when truthy, a ``DynamicGraphConstructor``
      is built from ``args``.
    """

    def __init__(self, args):
        super().__init__()
        # Kept so forward() can look up keys such as 'gap' and 'horizon'.
        self.args = args
        self.num_nodes = args['num_nodes']
        self.num_layers = args['num_layers']
        self.hidden_dim = args['num_hidden']
        self.forecast_dim = args['forecast_dim']
        self.output_hidden = args['output_hidden']
        self.output_dim = args['output_dim']
        self.in_feat = args['input_dim']

        self.embedding = nn.Linear(self.in_feat, self.hidden_dim)
        # Learnable time-of-day, day-of-week (7 rows) and node embeddings.
        self.T_i_D_emb = nn.Parameter(torch.empty(args.get('num_timesteps_in_day', 288), args['time_emb_dim']))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args['time_emb_dim']))
        self.node_u = nn.Parameter(torch.empty(self.num_nodes, args['node_hidden']))
        self.node_d = nn.Parameter(torch.empty(self.num_nodes, args['node_hidden']))

        self.layers = nn.ModuleList(
            [DecoupleLayer(self.hidden_dim, self.forecast_dim, args) for _ in range(self.num_layers)]
        )
        if args.get('dy_graph', False):
            self.dynamic_graph_constructor = DynamicGraphConstructor(**args)

        self.out_fc1 = nn.Linear(self.forecast_dim, self.output_hidden)
        # Emits `gap` steps of `output_dim` channels per coarse position;
        # _reshape_forecast unfolds them onto the time axis.
        self.out_fc2 = nn.Linear(self.output_hidden, args['gap'] * args['output_dim'])
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise the node and time embedding tables."""
        for p in [self.node_u, self.node_d, self.T_i_D_emb, self.D_i_W_emb]:
            nn.init.xavier_uniform_(p)

    def _prepare_inputs(self, x):
        """Split off the two trailing time channels and look up their embeddings.

        ``x`` is indexed as a 4-D tensor. The second-to-last channel is scaled
        by the number of time-of-day rows and floored to an index, so it is
        presumably a fraction of the day in [0, 1). The last channel is used
        directly as a day-of-week index into a 7-row table (assumed to be an
        integer 0-6; TODO confirm against the data loader).

        Returns:
            ``(features, node_u, node_d, t_in_day, t_in_week)``, where
            ``features`` is ``x`` without its last two channels.
        """
        node_u, node_d = self.node_u, self.node_d
        slots = self.T_i_D_emb.size(0)
        # Clamp so a time-of-day of exactly 1.0 (or a float round-up) maps to the
        # last slot instead of indexing one past the end of the table.
        tod_idx = (x[:, :, :, -2] * slots).long().clamp(0, slots - 1)
        t_in_day = self.T_i_D_emb[tod_idx]
        t_in_week = self.D_i_W_emb[x[:, :, :, -1].long()]
        return x[:, :, :, :-2], node_u, node_d, t_in_day, t_in_week

    def _graph_constructor(self, node_u, node_d, x, t_in_day, t_in_week):
        """Build the graphs used by the layers.

        Only a dynamic graph is produced; the static graph is always an empty
        list. The dynamic graph is also ``[]`` when no constructor was built
        (``args['dy_graph']`` falsy).

        Returns:
            ``(static_graph, dynamic_graph)``.
        """
        dyn_graph = self.dynamic_graph_constructor(node_u=node_u, node_d=node_d, history_data=x, time_in_day_feat=t_in_day, day_in_week_feat=t_in_week) if hasattr(self, 'dynamic_graph_constructor') else []
        return [], dyn_graph

    def _reshape_forecast(self, forecast):
        """Unfold the head output so dim 1 is the prediction horizon.

        The last dimension of ``forecast`` holds ``gap * output_dim`` values.
        Value ``g * output_dim + c`` is treated as channel ``c`` of the
        ``g``-th step inside that position's gap. Coarse step ``t`` therefore
        becomes output steps ``t * gap .. t * gap + gap - 1``. This uses every
        predicted value; previously all but the first ``output_dim`` were
        discarded.

        Assumes time is on dim 1 (TODO confirm against DifBlock/InhBlock
        outputs). If the unfolded length still differs from
        ``args['horizon']``, the sequence is tiled and truncated to the horizon,
        as in the earlier implementation.
        """
        gap = self.args['gap']
        # [..., gap * output_dim] -> [..., gap, output_dim]
        forecast = forecast.reshape(*forecast.shape[:-1], gap, self.output_dim)
        # Move the gap axis next to the time axis (dim 1) and merge the two.
        forecast = forecast.movedim(-2, 2).flatten(1, 2)

        horizon = self.args['horizon']
        steps = forecast.size(1)
        if steps != horizon:
            # Fallback for a remaining mismatch: tile along time, then cut to the horizon.
            reps = horizon // steps + 1
            forecast = forecast.repeat(1, reps, *([1] * (forecast.dim() - 2)))[:, :horizon]
        return forecast

    def forward(self, x):
        """Predict from raw input ``x``.

        ``x``'s last dimension must be ``input_dim + 2``: the input features
        followed by the time-of-day and day-of-week channels.

        Returns:
            Forecast tensor with ``args['horizon']`` steps on dim 1 and
            ``output_dim`` channels on the last dim.
        """
        x, node_u, node_d, t_in_day, t_in_week = self._prepare_inputs(x)
        sta_graph, dyn_graph = self._graph_constructor(node_u, node_d, x, t_in_day, t_in_week)
        x = self.embedding(x)

        dif_hidden_list, inh_hidden_list = [], []
        backcast = x
        for layer in self.layers:
            # Each layer consumes the previous layer's backcast.
            backcast, dif_hidden, inh_hidden = layer(backcast, dyn_graph, sta_graph, node_u, node_d, t_in_day, t_in_week)
            dif_hidden_list.append(dif_hidden)
            inh_hidden_list.append(inh_hidden)

        forecast_hidden = sum(dif_hidden_list) + sum(inh_hidden_list)
        forecast = self.out_fc2(F.relu(self.out_fc1(F.relu(forecast_hidden))))
        # Reshape so the output matches the label layout.
        return self._reshape_forecast(forecast)
|