67 lines
2.5 KiB
Python
67 lines
2.5 KiB
Python
import torch.nn as nn
|
|
|
|
from model.D2STGNN.dynamic_graph_conv.utils.distance import DistanceFunction
|
|
from model.D2STGNN.dynamic_graph_conv.utils.mask import Mask
|
|
|
|
from model.D2STGNN.dynamic_graph_conv.utils.normalizer import Normalizer, MultiOrder
|
|
|
|
|
|
class DynamicGraphConstructor(nn.Module):
    """Build dynamic graphs and tile them over a temporal window.

    ``forward`` chains five stages: ``DistanceFunction`` -> ``Mask`` ->
    ``Normalizer`` -> ``MultiOrder`` -> :meth:`st_localization`.
    """

    # Keyword inputs read by ``forward``. The last four (after history_data)
    # are forwarded positionally to the distance function in this order:
    # (history, E_d, E_u, T_D, T_W).
    _REQUIRED_INPUTS = (
        'history_data',
        'node_d',
        'node_u',
        'time_in_day_feat',
        'day_in_week_feat',
    )

    def __init__(self, **model_args):
        """Create the sub-modules from the shared model configuration.

        Args:
            **model_args: model configuration. This class reads ``k_s``,
                ``k_t``, ``num_hidden`` and ``node_hidden`` directly, and
                forwards the whole mapping to ``DistanceFunction`` and ``Mask``.

        Raises:
            KeyError: if one of the keys read directly is missing (the
                sub-module constructors may raise for their own keys).
        """
        super().__init__()
        # model args
        self.k_s = model_args['k_s']  # spatial order, passed to MultiOrder
        self.k_t = model_args['k_t']  # temporal kernel size, tiling factor in st_localization
        # hidden dimension; stored here but not used by this class's own methods
        self.hidden_dim = model_args['num_hidden']
        # trainable node embedding dimension; stored but not used by this class's own methods
        self.node_dim = model_args['node_hidden']

        self.distance_function = DistanceFunction(**model_args)
        self.mask = Mask(**model_args)
        self.normalizer = Normalizer()
        self.multi_order = MultiOrder(order=self.k_s)

    def st_localization(self, graph_ordered):
        """Repeat every k-order graph ``k_t`` times along its last axis.

        Args:
            graph_ordered: nested iterable as produced by ``MultiOrder``. The
                outer level is presumably one entry per modality; the inner
                level holds 3-D tensors (rank 3 is required by the
                ``expand`` call). NOTE(review): axes assumed to be (B, N, N);
                confirm against MultiOrder.

        Returns:
            list: flat list with one tensor per inner graph, in iteration
            order. A graph of shape (d0, d1, d2) becomes (d0, d1, k_t * d2),
            with ``out[..., t * d2:(t + 1) * d2] == graph`` for every t.
        """
        st_local_graph = []
        for modality_graphs in graph_ordered:
            for k_order_graph in modality_graphs:
                # (d0, d1, d2) -> (d0, d1, k_t, d2); expand returns a view.
                tiled = k_order_graph.unsqueeze(-2).expand(-1, -1, self.k_t, -1)
                d0, d1, repeats, d2 = tiled.shape
                # Merge the repeat axis into the last axis.
                st_local_graph.append(tiled.reshape(d0, d1, repeats * d2))
        return st_local_graph

    def forward(self, **inputs):
        """Dynamic graph learning module.

        Args (all required, passed as keyword arguments):
            history_data (torch.Tensor): input data with shape (B, L, N, D).
            node_d: node embedding E_d. Note the key is ``node_d``; earlier
                documentation called it ``node_embedding_d``.
            node_u: node embedding E_u. Note the key is ``node_u``; earlier
                documentation called it ``node_embedding_u``.
            time_in_day_feat: time embedding T_D.
            day_in_week_feat: time embedding T_W.

        Returns:
            list: dynamic graphs, i.e. the flat list returned by
            :meth:`st_localization`.

        Raises:
            KeyError: if any required keyword input is missing.
        """
        # Fail early with one message that lists every missing key, instead of
        # a bare KeyError for whichever key happens to be looked up first.
        missing = [key for key in self._REQUIRED_INPUTS if key not in inputs]
        if missing:
            raise KeyError(
                f"DynamicGraphConstructor.forward() missing required input(s) {missing}; "
                "note the node embeddings are passed as 'node_d' / 'node_u'"
            )

        X = inputs['history_data']
        E_d = inputs['node_d']  # key renamed to 'node_d'
        E_u = inputs['node_u']  # key renamed to 'node_u'
        T_D = inputs['time_in_day_feat']
        D_W = inputs['day_in_week_feat']
        # distance calculation
        dist_mx = self.distance_function(X, E_d, E_u, T_D, D_W)
        # mask
        dist_mx = self.mask(dist_mx)
        # normalization
        dist_mx = self.normalizer(dist_mx)
        # multi order
        mul_mx = self.multi_order(dist_mx)
        # spatial temporal localization
        dynamic_graphs = self.st_localization(mul_mx)

        return dynamic_graphs
|