import torch.nn as nn

from model.D2STGNN.dynamic_graph_conv.utils.distance import DistanceFunction
from model.D2STGNN.dynamic_graph_conv.utils.mask import Mask
from model.D2STGNN.dynamic_graph_conv.utils.normalizer import Normalizer, MultiOrder


class DynamicGraphConstructor(nn.Module):
    """Build a list of dynamic graphs from input data and embeddings.

    The forward pass chains four sub-modules and one local step:
    distance function -> mask -> normalizer -> multi-order expansion ->
    spatial-temporal localization (see ``st_localization``).
    """

    def __init__(self, **model_args):
        """Read the model hyper-parameters and create the sub-modules.

        Args:
            **model_args: keyword configuration. The keys read here are
                ``k_s``, ``k_t``, ``num_hidden`` and ``node_hidden``; the
                complete mapping is also forwarded to ``DistanceFunction``
                and ``Mask``.

        Raises:
            KeyError: if one of the four keys read here is missing.
        """
        super().__init__()

        # Hyper-parameters taken from the model arguments.
        self.k_s = model_args['k_s']  # spatial order
        self.k_t = model_args['k_t']  # temporal kernel size
        self.hidden_dim = model_args['num_hidden']  # hidden dimension
        self.node_dim = model_args['node_hidden']  # trainable node embedding dimension

        # Sub-modules, created in the order the forward pass applies them.
        self.distance_function = DistanceFunction(**model_args)
        self.mask = Mask(**model_args)
        self.normalizer = Normalizer()
        self.multi_order = MultiOrder(order=self.k_s)

    def _tile_over_kernel(self, graph):
        """Repeat one graph ``k_t`` times and merge the repeats into the last axis.

        A singleton axis is inserted before the last axis and expanded to
        ``self.k_t``; the two trailing axes are then flattened into one, so
        a 3-D input of shape (a, b, c) becomes (a, b, k_t * c).
        """
        tiled = graph.unsqueeze(-2).expand(-1, -1, self.k_t, -1)
        dim0, dim1, repeats, dim3 = tiled.shape
        return tiled.reshape(dim0, dim1, repeats * dim3)

    def st_localization(self, graph_ordered):
        """Flatten the nested graphs into one list of time-tiled graphs.

        Args:
            graph_ordered: iterable of iterables of tensors (the value
                returned by ``self.multi_order`` in ``forward``).

        Returns:
            list: one tiled tensor per inner graph, in iteration order.
        """
        localized = []
        for order_graphs in graph_ordered:
            localized.extend(
                self._tile_over_kernel(graph) for graph in order_graphs
            )
        return localized

    def forward(self, **inputs):
        """Dynamic graph learning module.

        Args:
            history_data (torch.Tensor): input data, documented as shape
                (B, L, N, D).
            node_d (torch.Parameter): node embedding E_d.
            node_u (torch.Parameter): node embedding E_u.
            time_in_day_feat (torch.Parameter): time embedding T_D.
            day_in_week_feat (torch.Parameter): time embedding D_W.

        Returns:
            list: dynamic graphs.

        Raises:
            KeyError: if any of the keys above is missing from ``inputs``.
        """
        history = inputs['history_data']
        emb_d = inputs['node_d']  # key was renamed to node_d
        emb_u = inputs['node_u']  # key was renamed to node_u
        time_in_day = inputs['time_in_day_feat']
        day_in_week = inputs['day_in_week_feat']

        # Distance calculation, then masking, then normalization.
        distances = self.distance_function(
            history, emb_d, emb_u, time_in_day, day_in_week
        )
        normalized = self.normalizer(self.mask(distances))

        # Multi-order expansion followed by spatial-temporal localization.
        return self.st_localization(self.multi_order(normalized))