71 lines
3.1 KiB
Python
71 lines
3.1 KiB
Python
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def _as_float_row(feature, device):
    """Return ``feature`` as a float tensor on ``device`` with a leading axis.

    ``np.ndarray`` inputs are copied into a new tensor first; anything else
    is expected to already be a ``torch.Tensor``.
    """
    if isinstance(feature, np.ndarray):
        feature = torch.tensor(feature)
    return feature.to(device).unsqueeze(0).float()


def GreedyLoss(pred_feats, true_feats, pred_missing, true_missing, num_pred):
    r"""Greedy loss: MSE between each generated neighbor feature and its
    closest real missing-neighbor feature.

    https://proceedings.neurips.cc//paper/2021/file/ \
    34adeb8e3242824038aa65460a47c29e-Paper.pdf
    Fedsageplus models from the "Subgraph Federated Learning with Missing
    Neighbor Generation" (FedSage+) paper, in NeurIPS'21
    Source: https://github.com/zkhku/fedsage

    For every node ``i`` and every generated neighbor ``j`` (``j`` below the
    rounded, clipped predicted count of node ``i``), the mean squared error
    against each of the first ``true_missing[i]`` real neighbor features is
    computed and the smallest one is kept. That value is written to every
    feature entry of ``loss[i, j]``; slots that are never visited stay zero
    and still take part in the final mean.

    Arguments:
        pred_feats (torch.Tensor): generated missing features, indexed as
            ``pred_feats[i][j]``
            (presumably ``(num_nodes, num_pred, feat_dim)`` - confirm
            against the caller)
        true_feats (torch.Tensor or np.ndarray): real missing features,
            indexed as ``true_feats[i][k]``; each entry may be a tensor or
            an ``np.ndarray``
        pred_missing (torch.Tensor): number of predicted missing node,
            rounded to the nearest integer before use
        true_missing (torch.Tensor): number of missing node
        num_pred (int): hyperparameters which limit the maximum value of
            the prediction; both counts are clipped to ``[0, num_pred]``
    :returns:
        loss : the Greedy Loss, i.e. the mean over a zero-initialised
        tensor shaped like ``pred_feats`` holding the per-slot minima
    :rtype:
        torch.FloatTensor
    """
    device = pred_feats.device

    # The counts are only used as Python loop bounds, so bring them to host
    # memory (``.cpu()`` is a no-op for tensors that already live there).
    pred_counts = np.round(
        pred_missing.detach().cpu().numpy()).reshape(-1).astype(np.int32)
    true_counts = true_missing.detach().cpu().numpy().reshape(-1).astype(
        np.int32)
    # After clipping, neither count can exceed ``num_pred``.
    pred_counts = np.clip(pred_counts, 0, num_pred)
    true_counts = np.clip(true_counts, 0, num_pred)

    loss = torch.zeros(pred_feats.shape, device=device)
    for i in range(len(pred_feats)):
        n_pred, n_true = int(pred_counts[i]), int(true_counts[i])
        if n_pred == 0 or n_true == 0:
            # Nothing generated, or nothing to match against: row stays 0.
            continue

        # The real features do not depend on ``j``; convert them once.
        targets = [
            _as_float_row(true_feats[i][k], device) for k in range(n_true)
        ]
        for j in range(n_pred):
            generated = pred_feats[i][j].unsqueeze(0).float()
            # ``F.mse_loss`` reduces to a scalar, so candidates must be
            # compared scalar-to-scalar. Comparing against the sum of the
            # broadcast row (as done previously) scaled the threshold by the
            # feature dimension and could keep a non-minimal candidate.
            candidates = torch.stack(
                [F.mse_loss(generated, target) for target in targets])
            loss[i, j] = candidates.min()
    return loss.unsqueeze(0).mean().float()
|