22 lines
685 B
Python
22 lines
685 B
Python
import torch
|
||
import torch.nn as nn
|
||
|
||
class Mask(nn.Module):
    """Element-wise mask a sequence of adjacency tensors with predefined masks.

    The ``i``-th input adjacency is multiplied by ``self.mask[i] + 1e-7``.
    When no predefined masks are supplied (``adjs`` missing, ``None`` or
    empty), the module is a no-op and returns the inputs unchanged.

    Keyword Args:
        adjs: Optional indexable collection of mask tensors; ``adjs[i]`` is
            applied to the ``i``-th input. Any other keyword is ignored.
    """

    # Added to every mask entry, so positions where the predefined mask is 0
    # are scaled by this tiny value instead of being zeroed exactly.
    _EPS = 1e-7

    def __init__(self, **model_args):
        super().__init__()
        # Predefined masks taken from the ``adjs`` keyword; None is allowed.
        self.mask = model_args.get('adjs', None)

    def _mask(self, index, adj):
        """Return ``adj`` masked by the ``index``-th predefined mask.

        Args:
            index: Position of the predefined mask to apply.
            adj: Adjacency tensor to mask.

        Returns:
            ``adj`` itself (same object) when no predefined masks exist;
            otherwise ``(self.mask[index] + 1e-7)`` moved to ``adj.device``
            and multiplied element-wise with ``adj``.
        """
        # len() rather than truthiness: ``self.mask`` might be a tensor, and a
        # multi-element tensor has an ambiguous truth value.
        if self.mask is None or len(self.mask) == 0:
            # No predefined adjacency matrices: return the input untouched.
            return adj
        # Adding a Python scalar broadcasts exactly like
        # ``ones_like(m) * eps`` but without allocating an extra tensor.
        mask = self.mask[index] + self._EPS
        return mask.to(adj.device) * adj

    def forward(self, adj):
        """Mask every adjacency in ``adj``.

        Args:
            adj: Iterable of adjacency tensors; element ``i`` is paired with
                predefined mask ``i``.

        Returns:
            list: One (possibly masked) tensor per element of ``adj``, in order.
        """
        return [self._mask(index, item) for index, item in enumerate(adj)]
|