# FS-TFP/federatedscope/cl/fedgc/utils.py
#
# 65 lines
# 2.5 KiB
# Python

import torch
import torch.nn as nn
import torch.nn.functional as F
def norm(w):
    """Return the L2 norm of every tensor in ``w`` treated as one flat vector.

    Arguments:
        w: mapping whose values are tensors; only ``w.values()`` is read
            (presumably a model ``state_dict`` -- confirm against callers).
    returns:
        The norm as a Python ``float``. An empty mapping yields ``0.0``.
    """
    pieces = []
    for v in w.values():
        flat = v.flatten()
        # torch.norm rejects integer/bool tensors, so promote those to float;
        # floating-point and complex tensors keep their own precision.
        if not (flat.is_floating_point() or flat.is_complex()):
            flat = flat.float()
        pieces.append(flat)
    if not pieces:
        # torch.cat raises on an empty list; the norm of "nothing" is zero.
        return 0.0
    return torch.norm(torch.cat(pieces)).item()
class global_NT_xentloss(nn.Module):
    r"""
    NT_xentloss definition adapted from https://github.com/PatrickHua/SimSiam

    The loss is computed on cosine similarities between two batches of
    embeddings; additional batches may be supplied as extra negatives.

    Arguments:
        temperature (float): value the similarity logits are divided by.
        device (torch.device): device the inputs are moved to in ``forward``.
    """
    def __init__(self, temperature=0.1, device=torch.device("cpu")):
        super().__init__()
        self.temperature = temperature
        self.device = device

    def forward(self, z1, z2, others_z2=None):
        r"""
        Compute the NT_xentloss for one batch.

        Arguments:
            z1 (torch.tensor): the embedding of model, shape ``(N, Z)``.
            z2 (torch.tensor): the embedding of model using another
                augmentation; concatenated with ``z1`` along dim 0.
            others_z2 (iterable of torch.tensor, optional): extra embedding
                batches used only as negatives. They are detached, so no
                gradient flows into them. ``None`` (default) or an empty
                sequence adds no extra negatives.
        returns:
            loss: the NT_xentloss loss for this batch data
        :rtype:
            torch.FloatTensor
        """
        if others_z2 is None:
            others_z2 = ()

        # Unpacking into two names keeps the requirement that z1 is 2-D.
        N, _ = z1.shape
        z1, z2 = z1.to(self.device), z2.to(self.device)

        representations = torch.cat([z1, z2], dim=0)
        similarity_matrix = F.cosine_similarity(representations.unsqueeze(1),
                                                representations.unsqueeze(0),
                                                dim=-1)

        # Offsets +N / -N pick the (z1[i], z2[i]) and (z2[i], z1[i]) pairs.
        l_pos = torch.diag(similarity_matrix, N)
        r_pos = torch.diag(similarity_matrix, -N)
        positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)

        # Exclude self-similarities and the positive pairs: an identity
        # pattern in each of the four N x N quadrants.
        eye = torch.eye(N, dtype=torch.bool, device=self.device)
        excluded = eye.repeat(2, 2)
        # Explicit column count (instead of -1) so that N == 1, where no
        # in-batch negatives remain, yields a (2, 0) tensor rather than an
        # "ambiguous -1" error on the empty selection.
        negatives = similarity_matrix[~excluded].view(2 * N, 2 * N - 2)

        for z2_ in others_z2:
            z2_ = z2_.detach().to(self.device)
            representations = torch.cat([z1, z2_], dim=0)
            similarity_matrix = F.cosine_similarity(
                representations.unsqueeze(1),
                representations.unsqueeze(0),
                dim=-1)
            # Keep only the cross blocks between z1 and the foreign batch.
            mask = torch.zeros_like(similarity_matrix,
                                    dtype=torch.bool,
                                    device=self.device)
            mask[N:, :N] = True
            mask[:N, N:] = True
            # NOTE(review): the last N rows are filled with (z2_, z1)
            # similarities and attached to the z2 anchors; they line up
            # row-by-row only when z2_ has N rows -- confirm this is intended.
            negatives_other = similarity_matrix[mask].view(
                2 * N, z2_.shape[0])
            negatives = torch.cat([negatives, negatives_other], dim=1)

        logits = torch.cat([positives, negatives], dim=1) / self.temperature
        # The positive sits in column 0 of every row.
        labels = torch.zeros(2 * N, dtype=torch.int64,
                             device=self.device)  # scalar label per sample
        loss = F.cross_entropy(logits, labels, reduction='sum') / (2 * N)
        return loss