TrafficWheel/model/D2STGNN/dynamic_graph_conv/utils/normalizer.py

43 lines
1.2 KiB
Python

import torch
import torch.nn as nn
def remove_nan_inf(x):
    """Replace every NaN and +/-inf entry of a tensor with zero.

    Args:
        x: input tensor of any shape and dtype.

    Returns:
        A new tensor with the same shape and dtype as ``x``, where each
        non-finite element has been replaced by 0 and all finite elements
        are kept unchanged.
    """
    # "Finite" is exactly "neither NaN nor +/-inf", so keeping the finite
    # entries and zeroing the rest matches an isnan | isinf mask.
    keep = torch.isfinite(x)
    zeros = torch.zeros_like(x)
    return torch.where(keep, x, zeros)
class Normalizer(nn.Module):
    """Row-normalize a collection of adjacency matrices, i.e. compute D^-1 A.

    Each graph is scaled row-wise by the reciprocal of its row sum. Non-finite
    reciprocals (e.g. from rows that sum to zero) are replaced by zero via
    ``remove_nan_inf``, so such rows become all zeros instead of inf/NaN. For
    non-negative weights, every other row of the result sums to 1.
    """

    def __init__(self):
        super().__init__()

    def _norm(self, graph):
        """Return ``graph`` with every row divided by its row sum.

        Args:
            graph: tensor whose last two dims form a square (N, N) matrix;
                any number of leading batch dims is accepted.

        Returns:
            Tensor with the same shape as ``graph``.
        """
        degree = graph.sum(dim=-1)
        inv_degree = remove_nan_inf(1 / degree)
        # Multiplying row i by inv_degree[..., i] equals diag(inv_degree) @ graph,
        # but broadcasting avoids materializing a dense (N, N) diagonal and an
        # O(N^3) batched matmul, and a NaN in one row no longer leaks into every
        # column through 0 * NaN terms of the matmul.
        return graph * inv_degree.unsqueeze(-1)

    def forward(self, adj):
        """Row-normalize every graph in the iterable ``adj``.

        Returns:
            A list with one normalized tensor per input graph, in input order.
        """
        return [self._norm(graph) for graph in adj]
class MultiOrder(nn.Module):
    """Compute successive matrix powers of each graph with self-loops removed.

    For a graph A, produces [A, A^2, ..., A^order], each with its diagonal
    (the last two dims) set to zero.

    Args:
        order: highest power to compute. For ``order >= 1`` each graph yields
            ``order`` matrices; smaller values still yield just [A].
    """

    def __init__(self, order=2):
        super().__init__()
        self.order = order

    def _multi_order(self, graph):
        """Return the list [A, A^2, ..., A^order] for one graph, diagonals zeroed.

        Args:
            graph: tensor whose last two dims form a square (N, N) matrix;
                leading batch dims are broadcast over.

        Returns:
            List of tensors, each with the same shape and dtype as ``graph``.
        """
        n = graph.shape[-1]
        # Boolean identity created directly on the graph's device. masked_fill
        # keeps graph's dtype (multiplying by a float32 eye would upcast
        # half/bfloat16 inputs) and zeroes the diagonal even where it is inf
        # (inf * 0 would give NaN).
        diagonal = torch.eye(n, dtype=torch.bool, device=graph.device)
        power = graph  # first-order term
        graph_ordered = [power.masked_fill(diagonal, 0)]
        # The chain multiplies the *unmasked* previous power; only the returned
        # copies have their diagonal cleared.
        for _ in range(2, self.order + 1):
            power = torch.matmul(power, graph)
            graph_ordered.append(power.masked_fill(diagonal, 0))
        return graph_ordered

    def forward(self, adj):
        """Apply ``_multi_order`` to every graph in the iterable ``adj``.

        Returns:
            A list (one entry per input graph, in order) of lists of powers.
        """
        return [self._multi_order(graph) for graph in adj]