import torch
import torch.nn as nn
import torch.nn.functional as F

def norm(w):
    """Global L2 norm of a dict of tensors (e.g. a model state_dict)."""
    return torch.norm(torch.cat([v.flatten() for v in w.values()])).item()
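# Hypothetical usage of norm(): for w = {"w": torch.ones(2, 2)},
# norm(w) == 2.0, the square root of the sum of squares of all entries.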

class global_NT_xentloss(nn.Module):
    r"""
    NT-Xent loss definition adapted from https://github.com/PatrickHua/SimSiam

    Arguments:
        z1 (torch.Tensor): embeddings of the model under one augmentation.
        z2 (torch.Tensor): embeddings of the same model under another augmentation.

    Returns:
        loss: the NT-Xent loss for this batch of data.

    :rtype: torch.FloatTensor
    """
    def __init__(self, temperature=0.1, device=torch.device("cpu")):
        super().__init__()
        self.temperature = temperature
        self.device = device

    def forward(self, z1, z2, others_z2=None):
        # Use None instead of a mutable default argument.
        if others_z2 is None:
            others_z2 = []
        N, Z = z1.shape
        z1, z2 = z1.to(self.device), z2.to(self.device)
        representations = torch.cat([z1, z2], dim=0)
        # (2N, 2N) matrix of pairwise cosine similarities between all embeddings.
        similarity_matrix = F.cosine_similarity(representations.unsqueeze(1),
                                                representations.unsqueeze(0),
                                                dim=-1)

        # Positive pairs sit on the +/-N off-diagonals: row i of z1 with row i of z2.
        l_pos = torch.diag(similarity_matrix, N)
        r_pos = torch.diag(similarity_matrix, -N)
        positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)

        # Mask out self-similarities and the positive pairs; the remaining
        # 2N - 2 entries per row are the negatives.
        diag = torch.eye(2 * N, dtype=torch.bool, device=self.device)
        diag[N:, :N] = diag[:N, N:] = torch.eye(N, dtype=torch.bool, device=self.device)
        negatives = similarity_matrix[~diag].view(2 * N, -1)

        if len(others_z2) != 0:
            # Embeddings from other models serve as additional negatives only;
            # detach() keeps gradients from flowing into them.
            for z2_ in others_z2:
                z2_ = z2_.detach().to(self.device)
                N2, Z2 = z2_.shape  # the reshape below assumes N2 == N
                representations = torch.cat([z1, z2_], dim=0)
                similarity_matrix = F.cosine_similarity(
                    representations.unsqueeze(1),
                    representations.unsqueeze(0),
                    dim=-1)
                # Keep only the cross blocks (z1 vs z2_ and z2_ vs z1).
                mask = torch.zeros_like(similarity_matrix,
                                        dtype=torch.bool,
                                        device=self.device)
                mask[N:, :N] = True
                mask[:N, N:] = True
                negatives_other = similarity_matrix[mask].view(2 * N, -1)
                negatives = torch.cat([negatives, negatives_other], dim=1)

        # The positive logit is placed at column 0, so every row's target class is 0.
        logits = torch.cat([positives, negatives], dim=1) / self.temperature
        labels = torch.zeros(2 * N, dtype=torch.int64,
                             device=self.device)  # scalar label per sample
        loss = F.cross_entropy(logits, labels, reduction='sum') / (2 * N)

        return loss
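
# A minimal usage sketch (assumed, not part of the original file): random tensors
# stand in for the embeddings of two augmented views; the batch size of 8 and
# embedding dimension of 128 are illustrative only.
if __name__ == "__main__":
    torch.manual_seed(0)
    criterion = global_NT_xentloss(temperature=0.1)
    z1 = torch.randn(8, 128)
    z2 = torch.randn(8, 128)
    print(criterion(z1, z2))                         # loss for one view pair
    print(criterion(z1, z2, [torch.randn(8, 128)]))  # with extra negatives from another model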