import torch
import torch.nn as nn


class Mask(nn.Module):
    """Element-wise re-weight a sequence of adjacency matrices by predefined masks.

    The i-th input matrix is multiplied by ``mask[i] + 1e-7``. If no masks
    were supplied (``adjs`` missing, ``None`` or empty), every input is
    returned unchanged.
    """

    def __init__(self, **model_args):
        """Store the optional predefined masks.

        Args:
            **model_args: keyword arguments; only ``adjs`` is read. It may be
                absent or ``None``. Presumably a sequence of tensors, one per
                adjacency matrix later passed to ``forward`` -- TODO confirm
                against callers.
        """
        super().__init__()
        # ``adjs`` is optional; None (or an empty sequence) disables masking.
        self.mask = model_args.get('adjs', None)

    def _mask(self, index, adj):
        """Apply the ``index``-th predefined mask to ``adj``.

        Args:
            index: position of ``adj`` in the sequence given to ``forward``.
            adj: tensor to be masked.

        Returns:
            ``adj`` unchanged when no masks are configured; otherwise
            ``(mask[index] + 1e-7)`` moved to ``adj.device`` times ``adj``.

        Raises:
            IndexError: if masks are configured but ``index`` is out of range.
        """
        # ``len`` is used instead of truthiness because a tensor's truth value
        # is ambiguous when it has more than one element.
        if self.mask is None or len(self.mask) == 0:
            # No predefined adjacency masks: pass the input through.
            return adj
        # Adding 1e-7 turns zero mask entries into a tiny positive weight, so
        # masked positions are scaled down rather than zeroed outright.
        # NOTE(review): presumably to keep values/gradients nonzero -- confirm.
        mask = self.mask[index] + 1e-7
        return mask.to(adj.device) * adj

    def forward(self, adj):
        """Mask every adjacency matrix in ``adj`` with its positional mask.

        Args:
            adj: an iterable of tensors (a list, or a tensor iterated along
                its first dimension); element i is paired with mask i.

        Returns:
            list: the masked tensors, in input order.
        """
        return [self._mask(index, matrix) for index, matrix in enumerate(adj)]