FS-TFP/federatedscope/gfl/loss/greedy_loss.py

71 lines
3.1 KiB
Python

import numpy as np
import torch
import torch.nn.functional as F
def _greedy_target(feat, device):
    """Return ``feat`` as a float tensor with a leading axis of size one.

    NumPy arrays are converted to tensors and moved to ``device``; any other
    input is used as-is and is expected to already be a tensor that lives on
    the same device as the predictions.
    """
    if isinstance(feat, np.ndarray):
        feat = torch.tensor(feat).to(device)
    return feat.unsqueeze(0).float()


def GreedyLoss(pred_feats, true_feats, pred_missing, true_missing, num_pred):
    r"""Greedy loss: for every generated neighbor feature, the MSE to the
    closest real missing-neighbor feature of the same node.

    https://proceedings.neurips.cc//paper/2021/file/ \
    34adeb8e3242824038aa65460a47c29e-Paper.pdf
    Fedsageplus models from the "Subgraph Federated Learning with Missing
    Neighbor Generation" (FedSage+) paper, in NeurIPS'21
    Source: https://github.com/zkhku/fedsage

    For node ``i`` only the first ``round(pred_missing[i])`` generated
    features and the first ``int(true_missing[i])`` real features are used,
    both counts being clipped to ``[0, num_pred]``. Each used generated
    feature ``pred_feats[i][j]`` is scored with the smallest
    ``F.mse_loss`` against the used real features, and that scalar fills
    ``loss[i][j]``. Slots that are not used (or nodes without real missing
    neighbors) keep a loss of zero. The result is the mean over the whole
    loss buffer, which has the same shape as ``pred_feats``.

    Arguments:
        pred_feats (torch.Tensor): generated missing features, indexed as
            ``pred_feats[node][slot]``
        true_feats (torch.Tensor): real missing features, indexed as
            ``true_feats[node][neighbor]``; individual entries may also be
            ``numpy.ndarray`` objects
        pred_missing (torch.Tensor): number of predicted missing node
        true_missing (torch.Tensor): number of missing node
        num_pred (int): hyperparameters which limit the maximum value of the \
            prediction
    :returns:
        loss : the Greedy Loss
    :rtype:
        torch.FloatTensor
    """
    device = pred_feats.device
    # One loss entry per element of ``pred_feats``; unused entries stay zero
    # and still take part in the final mean.
    loss = torch.zeros(pred_feats.shape, device=device)

    # Per-node counts as clipped int32 arrays. The predicted count is rounded
    # to the nearest integer, the true count is truncated by the cast.
    pred_counts = np.round(
        pred_missing.detach().cpu().numpy()).reshape(-1).astype(np.int32)
    true_counts = true_missing.detach().cpu().numpy().reshape(-1).astype(
        np.int32)
    pred_counts = np.clip(pred_counts, 0, num_pred)
    true_counts = np.clip(true_counts, 0, num_pred)

    for i in range(len(pred_feats)):
        n_pred, n_true = int(pred_counts[i]), int(true_counts[i])
        if n_pred == 0 or n_true == 0:
            continue

        # The targets do not depend on the predicted slot: convert them once.
        targets = [
            _greedy_target(true_feats[i][k], device) for k in range(n_true)
        ]

        for pred_j in range(n_pred):
            pred_vec = pred_feats[i][pred_j].unsqueeze(0).float()
            candidates = [F.mse_loss(pred_vec, tgt) for tgt in targets]

            # Start from the last real neighbor, then take any strictly
            # smaller candidate in index order. Scalars are compared
            # directly: summing the broadcast loss row instead would scale
            # the current best by the feature dimension and let a worse
            # candidate replace it.
            best = candidates[-1]
            for cand in candidates[:-1]:
                if cand.detach() < best.detach():
                    best = cand

            # Broadcast the scalar over the feature axis of this slot.
            loss[i][pred_j] = best

    return loss.unsqueeze(0).mean().float()