91 lines
3.1 KiB
Python
91 lines
3.1 KiB
Python
import contextlib
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch_geometric.data.batch import Batch
|
|
|
|
|
|
@contextlib.contextmanager
|
|
def _disable_tracking_bn_stats(model):
|
|
def switch_attr(m):
|
|
if hasattr(m, 'track_running_stats'):
|
|
m.track_running_stats ^= True
|
|
|
|
model.apply(switch_attr)
|
|
yield
|
|
model.apply(switch_attr)
|
|
|
|
|
|
def _l2_normalize(d):
|
|
d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2)))
|
|
d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8
|
|
return d
|
|
|
|
|
|
class VATLoss(nn.Module):
|
|
def __init__(self, xi=1e-3, eps=2.5, ip=1):
|
|
r"""VAT loss
|
|
Source: https://github.com/lyakaap/VAT-pytorch
|
|
|
|
Arguments:
|
|
xi: hyperparameter of VAT in Eq.9, default: 0.0001
|
|
eps: hyperparameter of VAT in Eq.9, default: 2.5
|
|
ip: iteration times of computing adv noise
|
|
|
|
Returns:
|
|
loss : the VAT Loss
|
|
|
|
"""
|
|
super(VATLoss, self).__init__()
|
|
self.xi = xi
|
|
self.eps = eps
|
|
self.ip = ip
|
|
|
|
def forward(self, model, graph, criterion):
|
|
pred = model(graph)
|
|
if criterion._get_name() == 'CrossEntropyLoss':
|
|
pred = torch.max(pred, dim=1).indices.long().view(-1)
|
|
|
|
# prepare random unit tensor
|
|
nodefea = graph.x
|
|
dn = torch.rand(nodefea.shape).sub(0.5).to(nodefea.device)
|
|
dn = _l2_normalize(dn)
|
|
|
|
with _disable_tracking_bn_stats(model):
|
|
# calc adversarial direction
|
|
with torch.enable_grad():
|
|
for _ in range(self.ip):
|
|
dn.requires_grad_()
|
|
x_neighbor = Batch(x=nodefea + self.xi * dn,
|
|
edge_index=graph.edge_index,
|
|
y=graph.y,
|
|
edge_attr=graph.edge_attr,
|
|
batch=graph.batch)
|
|
pred_hat = model(x_neighbor)
|
|
# logp_hat = F.log_softmax(pred_hat, dim=1)
|
|
# adv_distance = F.kl_div(logp_hat, logp,
|
|
# reduction='batchmean')
|
|
# adv_distance = ((pred - pred_hat) ** 2).sum(
|
|
# axis=0).sqrt()
|
|
adv_distance = criterion(pred_hat, pred)
|
|
# adv_distance.backward()
|
|
# dn = _l2_normalize(dn.grad)
|
|
dn = _l2_normalize(
|
|
torch.autograd.grad(outputs=adv_distance,
|
|
inputs=dn,
|
|
retain_graph=True)[0])
|
|
model.zero_grad()
|
|
del x_neighbor, pred_hat, adv_distance
|
|
|
|
# calc LDS
|
|
rn_adv = dn * self.eps
|
|
x_adv = Batch(x=nodefea + rn_adv,
|
|
edge_index=graph.edge_index,
|
|
y=graph.y,
|
|
edge_attr=graph.edge_attr,
|
|
batch=graph.batch)
|
|
pred_hat = model(x_adv)
|
|
lds = criterion(pred_hat, pred)
|
|
|
|
return lds
|