import numpy as np
import torch
import torch.nn.functional as F


def GreedyLoss(pred_feats, true_feats, pred_missing, true_missing, num_pred):
    r"""Greedy feature-matching MSE loss for generated missing neighbors.

    For every node ``i`` and every predicted missing-neighbor slot ``j``
    (``j`` below the node's rounded, clipped predicted count), the slot loss
    is the MSE between ``pred_feats[i][j]`` and the *closest* of the node's
    real missing-neighbor features::

        min_k MSE(pred_feats[i][j], true_feats[i][k])

    Here ``k`` ranges over the node's clipped true count. Unused slots, and
    nodes with no real missing neighbors, contribute zero. The returned value
    is the mean of a tensor shaped like ``pred_feats``, in which each slot's
    scalar loss is broadcast across that slot's trailing entries.

    Reference: Fedsageplus models from the "Subgraph Federated Learning with
    Missing Neighbor Generation" (FedSage+) paper, in NeurIPS'21.
    https://proceedings.neurips.cc//paper/2021/file/ \
    34adeb8e3242824038aa65460a47c29e-Paper.pdf
    Source: https://github.com/zkhku/fedsage

    Arguments:
        pred_feats (torch.Tensor): generated missing features, indexed as
            ``pred_feats[i][j]``. Presumably shaped
            ``(num_nodes, num_pred, feat_dim)`` -- TODO confirm with callers.
        true_feats (torch.Tensor | np.ndarray | sequence): real missing
            features, indexed as ``true_feats[i][k]``. Entries are converted
            with ``torch.as_tensor`` onto ``pred_feats.device``.
        pred_missing (torch.Tensor): predicted number of missing neighbors
            per node; rounded to the nearest integer.
        true_missing (torch.Tensor): real number of missing neighbors per
            node.
        num_pred (int): upper bound both counts are clipped to.

    :returns:
        loss : the Greedy Loss (0-d tensor)
    :rtype:
        torch.FloatTensor
    """
    device = pred_feats.device

    # The counts only drive Python-level loop bounds, so bring them to the
    # host. ``.cpu()`` is a no-op for tensors already there.
    pred_counts = np.round(pred_missing.detach().cpu().numpy())
    pred_counts = np.clip(pred_counts.reshape(-1).astype(np.int32), 0,
                          num_pred)
    true_counts = true_missing.detach().cpu().numpy()
    true_counts = np.clip(true_counts.reshape(-1).astype(np.int32), 0,
                          num_pred)

    loss = torch.zeros(pred_feats.shape, device=device)
    for i in range(len(pred_feats)):
        n_true = int(true_counts[i])
        n_pred = int(pred_counts[i])
        if n_true == 0 or n_pred == 0:
            continue
        # Convert this node's real features once per node, not once per
        # (slot, candidate) pair. NumPy arrays and tensors on another device
        # both end up on ``device``.
        candidates = [
            torch.as_tensor(true_feats[i][k], device=device).float()
            for k in range(n_true)
        ]
        for j in range(n_pred):
            pred_vec = pred_feats[i][j].float()
            # Take the minimum of the scalar MSEs directly. Comparing a
            # candidate against the sum of the broadcast slot row would scale
            # the running best by feat_dim and could keep a non-minimal match.
            slot_losses = torch.stack(
                [F.mse_loss(pred_vec, true_vec) for true_vec in candidates])
            loss[i][j] = slot_losses.min()
    return loss.mean().float()