import torch
import torch as pt  # parts of this file use the `pt` alias
import torch.nn.functional as F
import numpy as np
from collections import namedtuple
from torch import nn
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.nn.utils import spectral_norm


class EncNet(nn.Module):
    """Spectral-normalized two-layer MLP mapping a client encoding to `num_params` outputs."""

    def __init__(self, in_channel, num_params, hid_dim=64):
        super(EncNet, self).__init__()
        self.num_params = num_params
        self.fc_layer = nn.Sequential(
            spectral_norm(nn.Linear(in_channel, hid_dim, bias=False)),
            nn.ReLU(inplace=True),
            spectral_norm(nn.Linear(hid_dim, num_params, bias=False)),
        )

    def forward(self, client_enc):
        mean_update = self.fc_layer(client_enc)
        return mean_update


class PolicyNet(nn.Module):
    """Spectral-normalized three-layer MLP used to update the sampling mean."""

    def __init__(self, in_channel, num_params, hid_dim=32):
        super(PolicyNet, self).__init__()
        self.num_params = num_params

        self.fc_layer = nn.Sequential(
            spectral_norm(nn.Linear(in_channel, hid_dim)),
            nn.ReLU(inplace=True),
            spectral_norm(nn.Linear(hid_dim, hid_dim)),
            nn.ReLU(inplace=True),
            spectral_norm(nn.Linear(hid_dim, num_params)),
        )

    def forward(self, client_enc):
        mean_update = self.fc_layer(client_enc)
        return mean_update


class DisHyperNet(nn.Module):
    """Hypernetwork for a discrete search space: one softmax head per hyperparameter."""

    def __init__(
        self,
        encoding,
        cands,
        n_clients,
        device,
    ):
        super(DisHyperNet, self).__init__()
        num_params = len(cands)
        self.dim = input_dim = encoding.shape[1]
        self.encoding = torch.nn.Parameter(encoding, requires_grad=True)
        self.EncNet = EncNet(input_dim, 64)
        loss_type = 'sphereface'
        self.enc_loss = AngularPenaltySMLoss(64, 10, loss_type)
        # Regularization strength depends on the angular-penalty loss in use.
        if loss_type == 'sphereface':
            self.reg_alpha = 0.1
        if loss_type == 'cosface':
            self.reg_alpha = 0.2
        if loss_type == 'arcface':
            self.reg_alpha = 0.1

        # One linear + softmax head per hyperparameter, sized to its candidate
        # list. dim=-1 is made explicit: nn.Softmax() without dim is deprecated.
        self.out = nn.ModuleList()
        for k, v in cands.items():
            self.out.append(
                nn.Sequential(nn.Linear(64, len(v), bias=False),
                              nn.Softmax(dim=-1)))

    def forward(self):
        client_enc = self.EncNet(self.encoding)
        client_enc_reg = 0
        logits = []
        for module in self.out:
            out = module(client_enc)
            # out = torch.cat([out]*10, dim=0)
            logits.append(out)
        return logits, client_enc_reg
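

# Illustrative usage sketch, not part of the original file: `cands` maps each
# hyperparameter name to its list of discrete candidates, and the forward pass
# returns one probability vector per hyperparameter. All sizes are made up.
def _demo_dis_hypernet():
    enc = torch.randn(4, 16)  # hypothetical: 4 clients, 16-dim encodings
    cands = {'lr': [0.001, 0.01, 0.1], 'batch_size': [16, 32]}
    hnet = DisHyperNet(enc, cands, n_clients=4, device='cpu')
    probs, _ = hnet()
    print([p.shape for p in probs])  # [torch.Size([4, 3]), torch.Size([4, 2])]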


class HyperNet(nn.Module):
    """Hypernetwork for a continuous search space: samples per-client
    hyperparameter values in [0, 1] from a Gaussian with a learned mean."""

    def __init__(self, encoding, num_params, n_clients, device, var):
        super(HyperNet, self).__init__()
        self.dim = input_dim = encoding.shape[1]
        self.var = var
        self.encoding = torch.nn.Parameter(encoding, requires_grad=True)
        self.mean = torch.zeros((n_clients, num_params)).to(device) + 0.5

        self.EncNet = EncNet(input_dim, num_params)
        self.meanNet = PolicyNet(num_params, num_params)
        self.combine = nn.Sequential(nn.Linear(num_params * 2, num_params),
                                     nn.Sigmoid())

        self.alpha = 0.8

    def forward(self):
        client_enc = self.EncNet(self.encoding)
        mean_update = self.meanNet(self.mean)
        mean = self.combine(torch.cat([client_enc, mean_update], dim=-1))

        # Isotropic Gaussian around the predicted mean; samples are clamped
        # back to the valid [0, 1] range.
        cov_matrix = torch.eye(mean.shape[-1]).to(mean.device) * self.var
        dist = MultivariateNormal(loc=mean, covariance_matrix=cov_matrix)
        sample = dist.sample()
        sample = torch.clamp(sample, 0., 1.)
        logprob = dist.log_prob(sample)
        entropy = dist.entropy()
        self.mean.data.copy_(mean.data)

        return sample, logprob, entropy
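

# Illustrative usage sketch, not part of the original file: HyperNet is
# typically driven by a REINFORCE-style update, since `sample` itself is not
# differentiable but `logprob` is. The reward signal here is fake.
def _demo_hypernet():
    enc = torch.randn(4, 16)  # hypothetical: 4 clients, 16-dim encodings
    hnet = HyperNet(enc, num_params=3, n_clients=4, device='cpu', var=0.01)
    sample, logprob, entropy = hnet()  # shapes: (4, 3), (4,), (4,)
    reward = torch.randn(4)  # hypothetical per-client reward
    loss = -(reward * logprob).mean() - 0.01 * entropy.mean()
    loss.backward()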


def parse_pbounds(search_space):
    """Extract (lower, upper) bounds per hyperparameter, log10-scaled where v.log is set."""
    pbounds = {}
    for k, v in search_space.items():
        if not (hasattr(v, 'lower') and hasattr(v, 'upper')):
            raise ValueError("Unsupported hyper type {}".format(type(v)))
        if v.log:
            l, u = np.log10(v.lower), np.log10(v.upper)
        else:
            l, u = v.lower, v.upper
        pbounds[k] = (l, u)
    return pbounds


def map_value_to_param(x, pbounds, ss):
    """Map unit-interval values back to concrete hyperparameters via their bounds."""
    x = np.array(x).reshape(-1)
    assert len(x) == len(pbounds)
    params = {}

    # Relies on dict insertion order: x[i] pairs with the i-th pbounds entry.
    for i, (k, b) in enumerate(pbounds.items()):
        p_inst = ss[k]
        l, u = b
        p = float(x[i] * (u - l) + l)
        if p_inst.log:
            p = 10**p
        params[k] = int(p) if 'int' in str(type(p_inst)).lower() else p
    return params
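

# Illustrative round trip, not part of the original file: the `Float` stub
# stands in for whatever search-space class callers actually use; it only
# needs `lower`, `upper`, and `log` attributes.
def _demo_param_mapping():
    Float = namedtuple('Float', ['lower', 'upper', 'log'])
    ss = {'lr': Float(1e-4, 1e-1, True), 'dropout': Float(0.0, 0.5, False)}
    pbounds = parse_pbounds(ss)  # {'lr': (-4.0, -1.0), 'dropout': (0.0, 0.5)}
    print(map_value_to_param([0.5, 0.2], pbounds, ss))
    # {'lr': 0.0031622776601683794, 'dropout': 0.1}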


class AngularPenaltySMLoss(nn.Module):
    """Angular-penalty softmax losses: ArcFace, SphereFace, and CosFace."""

    def __init__(self,
                 in_features,
                 out_features,
                 loss_type='arcface',
                 eps=1e-7,
                 s=None,
                 m=None):
        super(AngularPenaltySMLoss, self).__init__()
        loss_type = loss_type.lower()
        assert loss_type in ['arcface', 'sphereface', 'cosface']
        # Default scale s and margin m follow the respective papers.
        if loss_type == 'arcface':
            self.s = 64.0 if not s else s
            self.m = 0.5 if not m else m
        if loss_type == 'sphereface':
            self.s = 64.0 if not s else s
            self.m = 1.35 if not m else m
        if loss_type == 'cosface':
            self.s = 30.0 if not s else s
            self.m = 0.4 if not m else m
        self.loss_type = loss_type
        self.in_features = in_features
        self.out_features = out_features
        self.fc = nn.Linear(in_features, out_features, bias=False)
        self.eps = eps

    def forward(self, x, labels):
        assert len(x) == len(labels)
        assert torch.min(labels) >= 0
        assert torch.max(labels) < self.out_features

        # Normalize the class weights in place so that self.fc(x) yields
        # cosine similarities. (The original `W = F.normalize(W, ...)` only
        # rebound a local name and never modified the parameter.)
        for W in self.fc.parameters():
            W.data = F.normalize(W.data, p=2, dim=1)
        x = F.normalize(x, p=2, dim=1)

        wf = self.fc(x)
        # cos(theta_y) for the target class of each sample:
        target_cos = torch.diagonal(wf.transpose(0, 1)[labels])
        if self.loss_type == 'cosface':
            numerator = self.s * (target_cos - self.m)
        if self.loss_type == 'arcface':
            numerator = self.s * torch.cos(
                torch.acos(
                    torch.clamp(target_cos, -1. + self.eps, 1 - self.eps)) +
                self.m)
        if self.loss_type == 'sphereface':
            numerator = self.s * torch.cos(self.m * torch.acos(
                torch.clamp(target_cos, -1. + self.eps, 1 - self.eps)))

        # Logits of all non-target classes for each sample.
        excl = torch.cat([
            torch.cat((wf[i, :y], wf[i, y + 1:])).unsqueeze(0)
            for i, y in enumerate(labels)
        ],
                         dim=0)
        denominator = torch.exp(numerator) + torch.sum(
            torch.exp(self.s * excl), dim=1)
        L = numerator - torch.log(denominator)
        return -torch.mean(L)
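

# Illustrative usage sketch, not part of the original file: the loss acts as a
# classification head over embeddings plus integer class labels.
def _demo_angular_loss():
    criterion = AngularPenaltySMLoss(64, 10, loss_type='cosface')
    feats = torch.randn(8, 64)  # hypothetical embedding batch
    labels = torch.randint(0, 10, (8,))
    loss = criterion(feats, labels)
    loss.backward()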


def flat_data(data, labels, device, n_labels=10, add_label=False):
    """Flatten data to (batch, -1); optionally append one-hot labels."""
    bs = data.shape[0]
    if add_label:
        gen_one_hots = pt.zeros(bs, n_labels, device=device)
        gen_one_hots.scatter_(1, labels[:, None], 1)
        labels = gen_one_hots
        return pt.cat([pt.reshape(data, (bs, -1)), labels], dim=1)
    else:
        if len(data.shape) > 2:
            return pt.reshape(data, (bs, -1))
        else:
            return data


rff_param_tuple = namedtuple('rff_params', ['w', 'b'])


def rff_sphere(x, rff_params):
    """Random Fourier features with paired cos/sin ('sphere' variant)."""
    w = rff_params.w
    xwt = pt.mm(x, w.t())
    z_1 = pt.cos(xwt)
    z_2 = pt.sin(xwt)
    z_cat = pt.cat((z_1, z_2), 1)
    norm_const = pt.sqrt(pt.tensor(w.shape[0]).to(pt.float32))
    z = z_cat / norm_const  # w.shape[0] == n_features / 2
    return z


def weights_sphere(d_rff, d_enc, sig, device, seed=1234):
    """Draw the frequency matrix for rff_sphere; `sig` sets the kernel bandwidth."""
    np.random.seed(seed)
    freq = np.random.randn(d_rff // 2, d_enc) / np.sqrt(sig)
    w_freq = pt.tensor(freq).to(pt.float32).to(device)
    return rff_param_tuple(w=w_freq, b=None)


def rff_rahimi_recht(x, rff_params):
    """Random Fourier features in the Rahimi-Recht form: cos(xW^T + b)."""
    w = rff_params.w
    b = rff_params.b
    xwt = pt.mm(x, w.t()) + b
    z = pt.cos(xwt)
    z = z * pt.sqrt(pt.tensor(2. / w.shape[0]).to(pt.float32))
    return z


def weights_rahimi_recht(d_rff, d_enc, sig, device, seed=1234):
    """Draw frequencies and phase offsets for rff_rahimi_recht."""
    np.random.seed(seed)
    w_freq = pt.tensor(np.random.randn(d_rff, d_enc) / np.sqrt(sig)).to(
        pt.float32).to(device)
    # Cast the phases to float32 so they match the dtype of the frequencies
    # (np.random.rand yields float64, which would otherwise propagate).
    b_freq = pt.tensor(np.random.rand(d_rff) *
                       (2 * np.pi * sig)).to(pt.float32).to(device)
    return rff_param_tuple(w=w_freq, b=b_freq)
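

# Illustrative sketch, not part of the original file: both RFF variants map a
# (batch, d_enc) matrix to (batch, d_rff) features whose inner products
# approximate a Gaussian kernel. All sizes are made up for the demo.
def _demo_rff():
    x = pt.randn(5, 8)
    params = weights_sphere(d_rff=32, d_enc=8, sig=1.0, device='cpu')
    z = rff_sphere(x, params)
    print(z.shape)  # torch.Size([5, 32])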


def data_label_embedding(data,
                         labels,
                         rff_params,
                         mmd_type,
                         labels_to_one_hot=False,
                         n_labels=None,
                         device=None,
                         reduce='mean'):
    """Embed a batch as the outer product of RFF data features and one-hot labels."""
    assert reduce in {'mean', 'sum'}
    if labels_to_one_hot:
        batch_size = data.shape[0]
        one_hots = pt.zeros(batch_size, n_labels, device=device)
        one_hots.scatter_(1, labels[:, None], 1)
        labels = one_hots

    data_embedding = rff_sphere(data, rff_params) \
        if mmd_type == 'sphere' else rff_rahimi_recht(data, rff_params)
    embedding = pt.einsum('ki,kj->kij', [data_embedding, labels])
    return pt.mean(embedding, 0) if reduce == 'mean' else pt.sum(embedding, 0)


def noisy_dataset_embedding(train_loader,
                            d_enc,
                            sig,
                            d_rff,
                            device,
                            n_labels,
                            noise_factor,
                            mmd_type,
                            sum_frequency=25,
                            graph=False):
    """Accumulate the mean data-label embedding over a loader and add Gaussian noise."""
    emb_acc = []
    n_data = 0

    if mmd_type == 'sphere':
        w_freq = weights_sphere(d_rff, d_enc, sig, device, seed=1234)
    else:
        w_freq = weights_rahimi_recht(d_rff, d_enc, sig, device, seed=1234)

    if graph:
        for data in train_loader:
            data, labels = data.x.to(device), data.y.to(device).reshape(-1)
            # Graph batches can change the feature width, so re-draw the
            # frequencies to match the current encoding dimension.
            d_enc = data.shape[-1]
            if mmd_type == 'sphere':
                w_freq = weights_sphere(d_rff, d_enc, sig, device, seed=1234)
            else:
                w_freq = weights_rahimi_recht(d_rff,
                                              d_enc,
                                              sig,
                                              device,
                                              seed=1234)

            data = flat_data(data,
                             labels,
                             device,
                             n_labels=n_labels,
                             add_label=False)
            emb_acc.append(
                data_label_embedding(data,
                                     labels,
                                     w_freq,
                                     mmd_type,
                                     labels_to_one_hot=True,
                                     n_labels=n_labels,
                                     device=device,
                                     reduce='sum'))
            n_data += data.shape[0]

            # Collapse the running list periodically to bound memory use.
            if len(emb_acc) > sum_frequency:
                emb_acc = [pt.sum(pt.stack(emb_acc), 0)]

    else:
        for data, labels in train_loader:
            data, labels = data.to(device), labels.to(device)
            data = flat_data(data,
                             labels,
                             device,
                             n_labels=n_labels,
                             add_label=False)
            emb_acc.append(
                data_label_embedding(data,
                                     labels,
                                     w_freq,
                                     mmd_type,
                                     labels_to_one_hot=True,
                                     n_labels=n_labels,
                                     device=device,
                                     reduce='sum'))
            n_data += data.shape[0]

            if len(emb_acc) > sum_frequency:
                emb_acc = [pt.sum(pt.stack(emb_acc), 0)]

    emb_acc = pt.sum(pt.stack(emb_acc), 0) / n_data
    noise = pt.randn(d_rff, n_labels,
                     device=device) * (2 * noise_factor / n_data)
    noisy_emb = emb_acc + noise
    return noisy_emb
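

# Illustrative usage sketch, not part of the original file: computing a noisy
# mean embedding for a tiny in-memory dataset. All sizes are made up.
def _demo_noisy_embedding():
    from torch.utils.data import DataLoader, TensorDataset
    data = pt.randn(100, 784)
    labels = pt.randint(0, 10, (100,))
    loader = DataLoader(TensorDataset(data, labels), batch_size=32)
    emb = noisy_dataset_embedding(loader, d_enc=784, sig=50., d_rff=1000,
                                  device='cpu', n_labels=10, noise_factor=1.0,
                                  mmd_type='sphere')
    print(emb.shape)  # torch.Size([1000, 10])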


def merge_dict(dict1, dict2):
    """Recursively merge dict2 into dict1, accumulating leaf values in lists (for history)."""
    for key, value in dict2.items():
        if key not in dict1:
            if isinstance(value, dict):
                dict1[key] = merge_dict({}, value)
            else:
                dict1[key] = [value]
        else:
            if isinstance(value, dict):
                merge_dict(dict1[key], value)
            else:
                dict1[key].append(value)
    return dict1
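

# Illustrative usage sketch, not part of the original file: merge_dict turns
# repeated per-round result dicts into a history of lists.
def _demo_merge_dict():
    history = {}
    merge_dict(history, {'round': 1, 'metrics': {'acc': 0.71}})
    merge_dict(history, {'round': 2, 'metrics': {'acc': 0.74}})
    print(history)  # {'round': [1, 2], 'metrics': {'acc': [0.71, 0.74]}}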