import torch
import torch.nn as nn
import torch.nn.functional as F

from model.D2STGNN.diffusion_block.dif_block import DifBlock
from model.D2STGNN.inherent_block.inh_block import InhBlock
from model.D2STGNN.dynamic_graph_conv.dy_graph_conv import DynamicGraphConstructor
from model.D2STGNN.decouple.estimation_gate import EstimationGate


class DecoupleLayer(nn.Module):
    """One decoupling layer of D2STGNN.

    Gates the input through an estimation gate, passes it through a diffusion
    block, and feeds the diffusion backcast into an inherent block.

    Args:
        hidden_dim: hidden feature dimension of the blocks.
        fk_dim: forecast hidden dimension passed to both sub-blocks.
        args: configuration dict; must contain 'node_hidden' and
            'time_emb_dim', and is forwarded wholesale to the sub-blocks.
    """

    def __init__(self, hidden_dim, fk_dim, args):
        super().__init__()
        self.est_gate = EstimationGate(
            node_emb_dim=args['node_hidden'],
            time_emb_dim=args['time_emb_dim'],
            hidden_dim=64,
        )
        # Only the required parameters are named explicitly; everything else
        # (including the dy_graph flag) is forwarded through **args.
        self.dif_layer = DifBlock(hidden_dim, forecast_hidden_dim=fk_dim, **args)
        self.inh_layer = InhBlock(hidden_dim, forecast_hidden_dim=fk_dim, **args)

    def forward(self, x, dyn_graph, sta_graph=None, node_u=None, node_d=None,
                t_in_day=None, t_in_week=None):
        """Run one decouple step.

        Returns:
            inh_back: backcast from the inherent block (input to the next layer).
            dif_hidden: forecast hidden state from the diffusion block.
            inh_hidden: forecast hidden state from the inherent block.
        """
        gated_x = self.est_gate(node_u, node_d, t_in_day, t_in_week, x)
        dif_back, dif_hidden = self.dif_layer(x, gated_x, dyn_graph, sta_graph)
        inh_back, inh_hidden = self.inh_layer(dif_back)
        return inh_back, dif_hidden, inh_hidden


class D2STGNN(nn.Module):
    """Decoupled Dynamic Spatial-Temporal Graph Neural Network.

    Stacks `num_layers` DecoupleLayer modules; the forecast hidden states from
    all layers are summed and projected through two linear layers to produce
    the final forecast.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args  # kept so forward() can read output_dim / horizon
        self.num_nodes = args['num_nodes']
        self.num_layers = args['num_layers']
        self.hidden_dim = args['num_hidden']
        self.forecast_dim = args['forecast_dim']
        self.output_hidden = args['output_hidden']
        self.output_dim = args['output_dim']
        self.in_feat = args['input_dim']

        self.embedding = nn.Linear(self.in_feat, self.hidden_dim)

        # Learnable time-of-day / day-of-week embeddings.
        self.T_i_D_emb = nn.Parameter(
            torch.empty(args.get('num_timesteps_in_day', 288), args['time_emb_dim']))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args['time_emb_dim']))

        # Learnable source/target node embeddings for graph construction.
        self.node_u = nn.Parameter(torch.empty(self.num_nodes, args['node_hidden']))
        self.node_d = nn.Parameter(torch.empty(self.num_nodes, args['node_hidden']))

        self.layers = nn.ModuleList([
            DecoupleLayer(self.hidden_dim, self.forecast_dim, args)
            for _ in range(self.num_layers)
        ])

        # The dynamic graph constructor only exists when dy_graph is enabled;
        # _graph_constructor checks for the attribute with hasattr().
        if args.get('dy_graph', False):
            self.dynamic_graph_constructor = DynamicGraphConstructor(**args)

        self.out_fc1 = nn.Linear(self.forecast_dim, self.output_hidden)
        self.out_fc2 = nn.Linear(self.output_hidden, args['gap'] * args['output_dim'])

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialize the raw (torch.empty) embedding parameters."""
        for p in [self.node_u, self.node_d, self.T_i_D_emb, self.D_i_W_emb]:
            nn.init.xavier_uniform_(p)

    def _prepare_inputs(self, x):
        """Split features from time markers and look up time embeddings.

        Assumes the last two feature channels of `x` are the time-of-day
        fraction (in [0, 1)) and the day-of-week index (0..6) — TODO confirm
        against the data loader.
        """
        node_u, node_d = self.node_u, self.node_d
        t_in_day = self.T_i_D_emb[(x[:, :, :, -2] * self.T_i_D_emb.size(0)).long()]
        t_in_week = self.D_i_W_emb[x[:, :, :, -1].long()]
        return x[:, :, :, :-2], node_u, node_d, t_in_day, t_in_week

    def _graph_constructor(self, node_u, node_d, x, t_in_day, t_in_week):
        """Build the dynamic graph only; the static graph is not used (empty list)."""
        dyn_graph = self.dynamic_graph_constructor(
            node_u=node_u, node_d=node_d, history_data=x,
            time_in_day_feat=t_in_day, day_in_week_feat=t_in_week,
        ) if hasattr(self, 'dynamic_graph_constructor') else []
        return [], dyn_graph

    def forward(self, x):
        """Forecast from history `x` of shape (batch, time, node, feature).

        Returns a tensor trimmed/tiled to (batch, horizon, node, output_dim).
        """
        x, node_u, node_d, t_in_day, t_in_week = self._prepare_inputs(x)
        sta_graph, dyn_graph = self._graph_constructor(node_u, node_d, x, t_in_day, t_in_week)

        x = self.embedding(x)

        dif_hidden_list, inh_hidden_list = [], []
        backcast = x
        for layer in self.layers:
            backcast, dif_hidden, inh_hidden = layer(
                backcast, dyn_graph, sta_graph, node_u, node_d, t_in_day, t_in_week)
            dif_hidden_list.append(dif_hidden)
            inh_hidden_list.append(inh_hidden)

        # Aggregate forecast hidden states from every layer.
        forecast_hidden = sum(dif_hidden_list) + sum(inh_hidden_list)

        # Project to the output space.
        forecast = self.out_fc1(F.relu(forecast_hidden))
        forecast = F.relu(forecast)
        forecast = self.out_fc2(forecast)

        # Truncate the channel dimension if out_fc2 produced gap*output_dim > output_dim.
        if forecast.size(-1) != self.args['output_dim']:
            forecast = forecast[..., :self.args['output_dim']]

        # Tile along the time axis (then trim) if the prediction is shorter
        # than the requested horizon.
        if forecast.size(1) != self.args['horizon']:
            reps = self.args['horizon'] // forecast.size(1) + 1
            forecast = forecast.repeat(1, reps, 1, 1)[:, :self.args['horizon'], :, :]

        return forecast