TrafficWheel/model/D2STGNN/dynamic_graph_conv/utils/mask.py

22 lines
685 B
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
class Mask(nn.Module):
    """Element-wise mask a sequence of adjacency matrices with predefined priors.

    The i-th incoming adjacency ``adj[i]`` is multiplied by the i-th predefined
    matrix taken from ``model_args['adjs']``, after a small epsilon (1e-7) is
    added to that prior. When no predefined matrices are configured (``None``
    or empty), every input is passed through unchanged.
    """

    def __init__(self, **model_args):
        """Store the optional predefined adjacency priors.

        Args:
            **model_args: model configuration. Only the optional ``'adjs'``
                entry is read. It is indexed by graph position and sized with
                ``len()``, so presumably it is a list of tensors or a stacked
                tensor (TODO confirm against callers). ``None`` or an empty
                container disables masking.
        """
        super().__init__()
        # Stored as a plain attribute (not a registered buffer), so it is not
        # moved by ``module.to(device)``; _mask moves it to the input's device.
        self.mask = model_args.get('adjs', None)  # 'adjs' may be absent/None

    def _mask(self, index, adj):
        """Return ``adj`` masked by the ``index``-th predefined prior.

        Args:
            index: position of ``adj`` in the list passed to ``forward``.
            adj: adjacency tensor to mask.

        Returns:
            ``adj`` itself when no priors are configured; otherwise
            ``(prior + 1e-7) * adj`` on ``adj``'s device (and in ``adj``'s
            dtype when ``adj`` is floating point).

        Raises:
            IndexError: if priors are configured but there is none for
                ``index``.
        """
        if self.mask is None or len(self.mask) == 0:
            # No predefined adjacency: return the original adj unchanged.
            return adj
        if index >= len(self.mask):
            raise IndexError(
                f"Mask: no predefined adjacency for graph index {index} "
                f"({len(self.mask)} provided)"
            )
        prior = self.mask[index]
        # The epsilon means entries where the prior is 0 scale adj by 1e-7
        # instead of zeroing it outright.
        prior = prior + torch.ones_like(prior) * 1e-7
        if adj.is_floating_point():
            # Match adj's dtype so a higher-precision prior (e.g. float64)
            # does not silently upcast the result via type promotion.
            prior = prior.to(device=adj.device, dtype=adj.dtype)
        else:
            # Casting a float prior to an integer dtype would truncate it,
            # so only move devices here and keep the original promotion.
            prior = prior.to(adj.device)
        return prior * adj

    def forward(self, adj):
        """Mask every adjacency in ``adj``.

        Args:
            adj: iterable of adjacency tensors; position i is paired with the
                i-th predefined prior.

        Returns:
            A new list with one masked tensor per input, in the same order.
        """
        return [self._mask(index, graph) for index, graph in enumerate(adj)]