FS-TFP/federatedscope/gfl/loss/vat.py

91 lines
3.1 KiB
Python

import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data.batch import Batch
@contextlib.contextmanager
def _disable_tracking_bn_stats(model):
def switch_attr(m):
if hasattr(m, 'track_running_stats'):
m.track_running_stats ^= True
model.apply(switch_attr)
yield
model.apply(switch_attr)
def _l2_normalize(d):
d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2)))
d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8
return d
class VATLoss(nn.Module):
def __init__(self, xi=1e-3, eps=2.5, ip=1):
r"""VAT loss
Source: https://github.com/lyakaap/VAT-pytorch
Arguments:
xi: hyperparameter of VAT in Eq.9, default: 0.0001
eps: hyperparameter of VAT in Eq.9, default: 2.5
ip: iteration times of computing adv noise
Returns:
loss : the VAT Loss
"""
super(VATLoss, self).__init__()
self.xi = xi
self.eps = eps
self.ip = ip
def forward(self, model, graph, criterion):
pred = model(graph)
if criterion._get_name() == 'CrossEntropyLoss':
pred = torch.max(pred, dim=1).indices.long().view(-1)
# prepare random unit tensor
nodefea = graph.x
dn = torch.rand(nodefea.shape).sub(0.5).to(nodefea.device)
dn = _l2_normalize(dn)
with _disable_tracking_bn_stats(model):
# calc adversarial direction
with torch.enable_grad():
for _ in range(self.ip):
dn.requires_grad_()
x_neighbor = Batch(x=nodefea + self.xi * dn,
edge_index=graph.edge_index,
y=graph.y,
edge_attr=graph.edge_attr,
batch=graph.batch)
pred_hat = model(x_neighbor)
# logp_hat = F.log_softmax(pred_hat, dim=1)
# adv_distance = F.kl_div(logp_hat, logp,
# reduction='batchmean')
# adv_distance = ((pred - pred_hat) ** 2).sum(
# axis=0).sqrt()
adv_distance = criterion(pred_hat, pred)
# adv_distance.backward()
# dn = _l2_normalize(dn.grad)
dn = _l2_normalize(
torch.autograd.grad(outputs=adv_distance,
inputs=dn,
retain_graph=True)[0])
model.zero_grad()
del x_neighbor, pred_hat, adv_distance
# calc LDS
rn_adv = dn * self.eps
x_adv = Batch(x=nodefea + rn_adv,
edge_index=graph.edge_index,
y=graph.y,
edge_attr=graph.edge_attr,
batch=graph.batch)
pred_hat = model(x_adv)
lds = criterion(pred_hat, pred)
return lds