TrafficWheel/model/D2STGNN/dynamic_graph_conv/dy_graph_conv.py

67 lines
2.5 KiB
Python

import torch.nn as nn
from model.D2STGNN.dynamic_graph_conv.utils.distance import DistanceFunction
from model.D2STGNN.dynamic_graph_conv.utils.mask import Mask
from model.D2STGNN.dynamic_graph_conv.utils.normalizer import Normalizer, MultiOrder
class DynamicGraphConstructor(nn.Module):
    """Produce a list of dynamic graphs from input data and embeddings.

    ``forward`` chains five steps: distance function -> mask -> normalizer
    -> multi-order expansion -> spatial-temporal localization.
    """

    def __init__(self, **model_args):
        """Read hyper-parameters and instantiate the sub-modules.

        Args:
            **model_args: model configuration. The keys ``k_s``, ``k_t``,
                ``num_hidden`` and ``node_hidden`` are read here; the whole
                mapping is also forwarded to ``DistanceFunction`` and
                ``Mask``.
        """
        super().__init__()

        # Hyper-parameters taken from the model configuration.
        self.k_s = model_args['k_s']  # spatial order
        self.k_t = model_args['k_t']  # temporal kernel size
        self.hidden_dim = model_args['num_hidden']  # hidden dimension
        self.node_dim = model_args['node_hidden']  # trainable node embedding dimension

        # Sub-modules, created in the order forward() applies them.
        self.distance_function = DistanceFunction(**model_args)
        self.mask = Mask(**model_args)
        self.normalizer = Normalizer()
        self.multi_order = MultiOrder(order=self.k_s)

    def st_localization(self, graph_ordered):
        """Tile every graph ``k_t`` times along its last axis.

        Each tensor gets a new second-to-last axis that is expanded to
        ``self.k_t`` and then merged with the last axis, so a 3-D tensor of
        shape ``(d0, d1, d2)`` becomes ``(d0, d1, k_t * d2)``.

        Args:
            graph_ordered: iterable of iterables of 3-D tensors.

        Returns:
            list: the tiled tensors as one flat list, in iteration order.
        """
        localized = []
        flat_graphs = (graph for group in graph_ordered for graph in group)
        for graph in flat_graphs:
            # (d0, d1, d2) -> (d0, d1, 1, d2) -> (d0, d1, k_t, d2)
            tiled = graph.unsqueeze(-2).expand(-1, -1, self.k_t, -1)
            d0, d1, d2, d3 = tiled.shape
            # Merge the two trailing axes into one.
            localized.append(tiled.reshape(d0, d1, d2 * d3))
        return localized

    def forward(self, **inputs):
        """Dynamic graph learning module.

        Args:
            history_data (torch.Tensor): input data with shape (B, L, N, D)
            node_d (torch.Parameter): node embedding E_d
            node_u (torch.Parameter): node embedding E_u
            time_in_day_feat (torch.Parameter): time embedding T_D
            day_in_week_feat (torch.Parameter): time embedding T_W

        Returns:
            list: dynamic graphs
        """
        # Node embeddings are looked up under the keys 'node_d' / 'node_u'.
        history = inputs['history_data']
        emb_d = inputs['node_d']
        emb_u = inputs['node_u']
        time_in_day = inputs['time_in_day_feat']
        day_in_week = inputs['day_in_week_feat']

        # Distance calculation, then masking, then normalization.
        graph = self.distance_function(
            history, emb_d, emb_u, time_in_day, day_in_week)
        graph = self.normalizer(self.mask(graph))

        # Multi-order expansion followed by spatial-temporal localization.
        return self.st_localization(self.multi_order(graph))