TrafficWheel/model/D2STGNN/D2STGNN.py

89 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
from model.D2STGNN.diffusion_block.dif_block import DifBlock
from model.D2STGNN.inherent_block.inh_block import InhBlock
from model.D2STGNN.dynamic_graph_conv.dy_graph_conv import DynamicGraphConstructor
from model.D2STGNN.decouple.estimation_gate import EstimationGate
class DecoupleLayer(nn.Module):
    """One decoupling stage: estimation gate -> diffusion block -> inherent block.

    The gate output and the raw input both feed the diffusion block; the
    diffusion block's first output is then handed to the inherent block.
    """

    def __init__(self, hidden_dim, fk_dim, args):
        """Build the gate and the two blocks.

        Args:
            hidden_dim: forwarded positionally to ``DifBlock`` and ``InhBlock``.
            fk_dim: forwarded to both blocks as ``forecast_hidden_dim``.
            args: mapping of model options. ``'node_hidden'`` and
                ``'time_emb_dim'`` are read here; the whole mapping is also
                expanded as keyword arguments into both blocks.
        """
        super().__init__()
        gate_options = {
            'node_emb_dim': args['node_hidden'],
            'time_emb_dim': args['time_emb_dim'],
            'hidden_dim': 64,  # gate width is fixed, independent of hidden_dim
        }
        self.est_gate = EstimationGate(**gate_options)
        # Only the essential arguments are spelled out; everything else
        # (e.g. dy_graph) reaches the blocks through **args.
        self.dif_layer = DifBlock(hidden_dim, forecast_hidden_dim=fk_dim, **args)
        self.inh_layer = InhBlock(hidden_dim, forecast_hidden_dim=fk_dim, **args)

    def forward(self, x, dyn_graph, sta_graph=None, node_u=None, node_d=None, t_in_day=None, t_in_week=None):
        """Run the three sub-modules in sequence.

        Returns:
            A 3-tuple ``(inh_back, dif_hidden, inh_hidden)``: the first output
            of the inherent block, the second output of the diffusion block,
            and the second output of the inherent block.
        """
        gate_out = self.est_gate(node_u, node_d, t_in_day, t_in_week, x)
        diffusion_out = self.dif_layer(x, gate_out, dyn_graph, sta_graph)
        residual, dif_hidden = diffusion_out
        inherent_out = self.inh_layer(residual)
        backcast, inh_hidden = inherent_out
        return backcast, dif_hidden, inh_hidden
class D2STGNN(nn.Module):
    """Stack of ``DecoupleLayer`` modules followed by a two-layer regression head.

    ``forward`` takes a rank-4 tensor whose last axis ends with two time
    channels (second-to-last: time-of-day, last: day-of-week). Those two
    channels are turned into embeddings and stripped before the remaining
    channels go through the input projection.
    Presumably the layout is (batch, time, nodes, channels) - TODO confirm
    against the data loader.
    """

    def __init__(self, args):
        """Build all sub-modules from the ``args`` mapping.

        Required keys read here: ``num_nodes``, ``num_layers``, ``num_hidden``,
        ``forecast_dim``, ``output_hidden``, ``output_dim``, ``input_dim``,
        ``time_emb_dim``, ``node_hidden``, ``gap``.
        Optional keys: ``num_timesteps_in_day`` (default 288) and ``dy_graph``
        (default False). ``forward`` additionally reads ``horizon``.
        """
        super().__init__()
        self.args = args  # kept because forward() reads output_dim / gap / horizon from it
        self.num_nodes = args['num_nodes']
        self.num_layers = args['num_layers']
        self.hidden_dim = args['num_hidden']
        self.forecast_dim = args['forecast_dim']
        self.output_hidden = args['output_hidden']
        self.output_dim = args['output_dim']
        self.in_feat = args['input_dim']

        self.embedding = nn.Linear(self.in_feat, self.hidden_dim)

        # Learnable lookup tables; their values are set in _reset_parameters().
        self.T_i_D_emb = nn.Parameter(torch.empty(args.get('num_timesteps_in_day', 288), args['time_emb_dim']))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args['time_emb_dim']))
        self.node_u = nn.Parameter(torch.empty(self.num_nodes, args['node_hidden']))
        self.node_d = nn.Parameter(torch.empty(self.num_nodes, args['node_hidden']))

        self.layers = nn.ModuleList(
            [DecoupleLayer(self.hidden_dim, self.forecast_dim, args) for _ in range(self.num_layers)]
        )

        # The constructor only exists when dynamic graphs are enabled;
        # _graph_constructor() checks for the attribute with hasattr().
        if args.get('dy_graph', False):
            self.dynamic_graph_constructor = DynamicGraphConstructor(**args)

        self.out_fc1 = nn.Linear(self.forecast_dim, self.output_hidden)
        self.out_fc2 = nn.Linear(self.output_hidden, args['gap'] * args['output_dim'])
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise the node and time embedding tables."""
        for p in [self.node_u, self.node_d, self.T_i_D_emb, self.D_i_W_emb]:
            nn.init.xavier_uniform_(p)

    def _prepare_inputs(self, x):
        """Split ``x`` into data channels and time embeddings.

        Returns:
            ``(data, node_u, node_d, t_in_day, t_in_week)`` where ``data`` is
            ``x`` without its last two channels.
        """
        node_u, node_d = self.node_u, self.node_d
        # The time-of-day channel is scaled by the table size and truncated to
        # an integer row index.
        # NOTE(review): presumably the channel is a fraction in [0, 1); a value
        # of exactly 1.0 would index one past the table - confirm upstream.
        tod_index = (x[:, :, :, -2] * self.T_i_D_emb.size(0)).long()
        t_in_day = self.T_i_D_emb[tod_index]
        # The day-of-week channel is used directly as a row index (7 rows).
        t_in_week = self.D_i_W_emb[x[:, :, :, -1].long()]
        return x[:, :, :, :-2], node_u, node_d, t_in_day, t_in_week

    def _graph_constructor(self, node_u, node_d, x, t_in_day, t_in_week):
        """Return ``(static_graphs, dynamic_graphs)``.

        Only dynamic graphs are produced; the static graph list is always
        empty. Without a dynamic graph constructor both lists are empty.
        """
        if hasattr(self, 'dynamic_graph_constructor'):
            dyn_graph = self.dynamic_graph_constructor(
                node_u=node_u,
                node_d=node_d,
                history_data=x,
                time_in_day_feat=t_in_day,
                day_in_week_feat=t_in_week,
            )
        else:
            dyn_graph = []
        return [], dyn_graph

    def _match_label_shape(self, forecast):
        """Bring the regression output to ``horizon`` steps and ``output_dim`` channels.

        ``out_fc2`` emits ``gap * output_dim`` channels per hidden step. When
        ``steps * gap == horizon`` those channels are unfolded into the time
        axis (step-major, gap-minor) so every predicted value is used.
        In every other case the previous fallback applies: surplus channels
        are dropped and the time axis is tiled or truncated to ``horizon``.
        """
        output_dim = self.args['output_dim']
        horizon = self.args['horizon']
        gap = self.args.get('gap', 1)

        can_unfold = (
            gap > 1
            and forecast.dim() == 4
            and forecast.size(-1) == gap * output_dim
            and forecast.size(1) * gap == horizon
        )
        if can_unfold:
            batch, steps, nodes, _ = forecast.shape
            forecast = forecast.reshape(batch, steps, nodes, gap, output_dim)
            # Move the gap axis next to the step axis, then merge the two.
            forecast = forecast.permute(0, 1, 3, 2, 4).reshape(batch, steps * gap, nodes, output_dim)
            return forecast

        # Fallback: keep only the leading output_dim channels.
        if forecast.size(-1) != output_dim:
            forecast = forecast[..., :output_dim]

        # Fallback: tile along the time axis when too short, then cut to horizon.
        steps = forecast.size(1)
        if steps != horizon:
            repeats = (horizon + steps - 1) // steps  # ceiling division
            forecast = forecast.repeat(1, repeats, 1, 1)[:, :horizon, :, :]
        return forecast

    def forward(self, x):
        """Run the full model.

        Args:
            x: rank-4 tensor whose last two channels are the time-of-day and
                day-of-week features (see ``_prepare_inputs``).

        Returns:
            A tensor whose axis 1 has ``args['horizon']`` entries and whose
            last axis has at most ``args['output_dim']`` entries.
        """
        x, node_u, node_d, t_in_day, t_in_week = self._prepare_inputs(x)
        sta_graph, dyn_graph = self._graph_constructor(node_u, node_d, x, t_in_day, t_in_week)
        x = self.embedding(x)

        dif_hidden_list, inh_hidden_list = [], []
        backcast = x
        # Each layer consumes the previous layer's first output and
        # contributes two hidden tensors to the forecast.
        for layer in self.layers:
            backcast, dif_hidden, inh_hidden = layer(
                backcast, dyn_graph, sta_graph, node_u, node_d, t_in_day, t_in_week
            )
            dif_hidden_list.append(dif_hidden)
            inh_hidden_list.append(inh_hidden)
        forecast_hidden = sum(dif_hidden_list) + sum(inh_hidden_list)

        # Regression head: ReLU -> fc1 -> ReLU -> fc2.
        forecast = self.out_fc1(F.relu(forecast_hidden))
        forecast = F.relu(forecast)
        forecast = self.out_fc2(forecast)

        # Reshape the output so that it matches the labels.
        return self._match_label_shape(forecast)