# FS-TFP/federatedscope/autotune/pfedhpo/utils.py
from collections import namedtuple

import numpy as np
import torch
# `pt` is used as an alias for torch by the RFF/embedding helpers below.
import torch as pt
import torch.nn.functional as F
from torch import nn
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.nn.utils import spectral_norm


class EncNet(nn.Module):
    """Encodes a client representation into `num_params` outputs."""
    def __init__(self, in_channel, num_params, hid_dim=64):
        super(EncNet, self).__init__()
        self.num_params = num_params
        # Spectral normalization bounds each layer's Lipschitz constant,
        # which helps keep hypernetwork updates stable.
        self.fc_layer = nn.Sequential(
            spectral_norm(nn.Linear(in_channel, hid_dim, bias=False)),
            nn.ReLU(inplace=True),
            spectral_norm(nn.Linear(hid_dim, num_params, bias=False)),
        )

    def forward(self, client_enc):
        mean_update = self.fc_layer(client_enc)
        return mean_update


class PolicyNet(nn.Module):
    """Three-layer MLP that refines the running mean of the policy."""
    def __init__(self, in_channel, num_params, hid_dim=32):
        super(PolicyNet, self).__init__()
        self.num_params = num_params
        self.fc_layer = nn.Sequential(
            spectral_norm(nn.Linear(in_channel, hid_dim)),
            nn.ReLU(inplace=True),
            spectral_norm(nn.Linear(hid_dim, hid_dim)),
            nn.ReLU(inplace=True),
            spectral_norm(nn.Linear(hid_dim, num_params)),
        )

    def forward(self, client_enc):
        mean_update = self.fc_layer(client_enc)
        return mean_update


class DisHyperNet(nn.Module):
    """Hypernetwork over discrete hyperparameter candidates: one softmax
    head per hyperparameter, fed by a shared client encoding."""
    def __init__(
        self,
        encoding,
        cands,
        n_clients,
        device,
    ):
        super(DisHyperNet, self).__init__()
        num_params = len(cands)
        self.dim = input_dim = encoding.shape[1]
        # The client encodings are learned jointly with the network.
        self.encoding = torch.nn.Parameter(encoding, requires_grad=True)
        self.EncNet = EncNet(input_dim, 64)
        loss_type = 'sphereface'
        self.enc_loss = AngularPenaltySMLoss(64, 10, loss_type)
        if loss_type == 'sphereface':
            self.reg_alpha = 0.1
        if loss_type == 'cosface':
            self.reg_alpha = 0.2
        if loss_type == 'arcface':
            self.reg_alpha = 0.1
        # One head per hyperparameter; each outputs a distribution over
        # that hyperparameter's candidate values.
        self.out = nn.ModuleList()
        for k, v in cands.items():
            self.out.append(
                nn.Sequential(nn.Linear(64, len(v), bias=False),
                              nn.Softmax(dim=-1)))

    def forward(self):
        client_enc = self.EncNet(self.encoding)
        client_enc_reg = 0  # placeholder; no encoding regularizer applied
        logits = []
        for module in self.out:
            out = module(client_enc)
            # out = torch.cat([out]*10, dim=0)
            logits.append(out)
        return logits, client_enc_reg


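# Illustrative usage sketch for DisHyperNet. The encoding size, client
# count, and candidate grids below are arbitrary assumptions, not values
# taken from this repository.
def _demo_dis_hypernet():
    enc = torch.randn(4, 8)  # 4 clients, 8-dim encodings
    cands = {'lr': [0.01, 0.1, 1.0], 'batch_size': [16, 32]}
    net = DisHyperNet(enc, cands, n_clients=4, device='cpu')
    probs, _ = net()
    print([p.shape for p in probs])  # shapes: (4, 3) and (4, 2)

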
class HyperNet(nn.Module):
    """Hypernetwork over continuous hyperparameters: maintains a running
    per-client mean in [0, 1] and samples from a Gaussian around it."""
    def __init__(self, encoding, num_params, n_clients, device, var):
        super(HyperNet, self).__init__()
        self.dim = input_dim = encoding.shape[1]
        self.var = var
        self.encoding = torch.nn.Parameter(encoding, requires_grad=True)
        # Running mean of the policy, initialized at the center of [0, 1].
        self.mean = torch.zeros((n_clients, num_params)).to(device) + 0.5
        self.EncNet = EncNet(input_dim, num_params)
        self.meanNet = PolicyNet(num_params, num_params)
        self.combine = nn.Sequential(nn.Linear(num_params * 2, num_params),
                                     nn.Sigmoid())
        self.alpha = 0.8

    def forward(self):
        client_enc = self.EncNet(self.encoding)
        mean_update = self.meanNet(self.mean)
        mean = self.combine(torch.cat([client_enc, mean_update], dim=-1))
        # Isotropic Gaussian around the predicted mean, clipped to [0, 1].
        cov_matrix = torch.eye(mean.shape[-1]).to(mean.device) * self.var
        dist = MultivariateNormal(loc=mean, covariance_matrix=cov_matrix)
        sample = dist.sample()
        sample = torch.clamp(sample, 0., 1.)
        logprob = dist.log_prob(sample)
        entropy = dist.entropy()
        self.mean.data.copy_(mean.data)
        return sample, logprob, entropy


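# Illustrative usage sketch for HyperNet; the shapes (4 clients, 8-dim
# encodings, 3 hyperparameters) and the variance are arbitrary assumptions.
def _demo_hypernet():
    enc = torch.randn(4, 8)
    net = HyperNet(enc, num_params=3, n_clients=4, device='cpu', var=0.01)
    sample, logprob, entropy = net()
    # One hyperparameter vector in [0, 1] per client, with its log-density
    # for REINFORCE-style policy updates.
    print(sample.shape, logprob.shape, entropy.shape)

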
def parse_pbounds(search_space):
    """Extract (lower, upper) bounds for each hyperparameter, in log10
    space for log-scaled parameters."""
    pbounds = {}
    for k, v in search_space.items():
        if not (hasattr(v, 'lower') and hasattr(v, 'upper')):
            raise ValueError("Unsupported hyper type {}".format(type(v)))
        if v.log:
            l, u = np.log10(v.lower), np.log10(v.upper)
        else:
            l, u = v.lower, v.upper
        pbounds[k] = (l, u)
    return pbounds


def map_value_to_param(x, pbounds, ss):
    """Map a point `x` from the unit hypercube back to raw parameter
    values, undoing the log10 transform where needed."""
    x = np.array(x).reshape(-1)
    assert len(x) == len(pbounds)
    params = {}
    for i, (k, b) in enumerate(pbounds.items()):
        p_inst = ss[k]
        l, u = b
        p = float(1. * x[i] * (u - l) + l)
        if p_inst.log:
            p = 10**p
        # Integer hyperparameters are detected by their type name.
        params[k] = int(p) if 'int' in str(type(p_inst)).lower() else p
    return params


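# Illustrative sketch of the two helpers above, using a stand-in for the
# hyperparameter objects (real usage would pass objects exposing
# .lower/.upper/.log, e.g. ConfigSpace hyperparameters).
def _demo_pbounds():
    class _FloatHP:
        def __init__(self, lower, upper, log=False):
            self.lower, self.upper, self.log = lower, upper, log

    ss = {'lr': _FloatHP(1e-4, 1e-1, log=True), 'dropout': _FloatHP(0., 0.5)}
    pbounds = parse_pbounds(ss)
    # Map a point from the unit hypercube back to the raw parameter space.
    print(map_value_to_param([0.5, 0.2], pbounds, ss))
    # -> lr = 10 ** (0.5 * (-1 - (-4)) + (-4)) ~= 3.16e-3, dropout = 0.1

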
class AngularPenaltySMLoss(nn.Module):
    """Angular-penalty softmax losses (ArcFace / SphereFace / CosFace)
    with the default scales and margins from the respective papers."""
    def __init__(self,
                 in_features,
                 out_features,
                 loss_type='arcface',
                 eps=1e-7,
                 s=None,
                 m=None):
        super(AngularPenaltySMLoss, self).__init__()
        loss_type = loss_type.lower()
        assert loss_type in ['arcface', 'sphereface', 'cosface']
        if loss_type == 'arcface':
            self.s = 64.0 if s is None else s
            self.m = 0.5 if m is None else m
        if loss_type == 'sphereface':
            self.s = 64.0 if s is None else s
            self.m = 1.35 if m is None else m
        if loss_type == 'cosface':
            self.s = 30.0 if s is None else s
            self.m = 0.4 if m is None else m
        self.loss_type = loss_type
        self.in_features = in_features
        self.out_features = out_features
        self.fc = nn.Linear(in_features, out_features, bias=False)
        self.eps = eps

    def forward(self, x, labels):
        assert len(x) == len(labels)
        assert torch.min(labels) >= 0
        assert torch.max(labels) < self.out_features
        # Normalize the class weights in place so that `wf` below holds
        # cosine similarities. (Rebinding `W = F.normalize(W, ...)` in a
        # loop would have no effect on the stored parameters.)
        with torch.no_grad():
            for W in self.fc.parameters():
                W.copy_(F.normalize(W, p=2, dim=1))
        x = F.normalize(x, p=2, dim=1)
        wf = self.fc(x)
        # Numerator: the target-class logit with the angular penalty applied.
        if self.loss_type == 'cosface':
            numerator = self.s * (torch.diagonal(wf.transpose(0, 1)[labels]) -
                                  self.m)
        if self.loss_type == 'arcface':
            numerator = self.s * torch.cos(
                torch.acos(
                    torch.clamp(torch.diagonal(wf.transpose(0, 1)[labels]),
                                -1. + self.eps, 1 - self.eps)) + self.m)
        if self.loss_type == 'sphereface':
            numerator = self.s * torch.cos(self.m * torch.acos(
                torch.clamp(torch.diagonal(wf.transpose(0, 1)[labels]),
                            -1. + self.eps, 1 - self.eps)))
        # Denominator: penalized target logit plus all non-target logits.
        excl = torch.cat([
            torch.cat((wf[i, :y], wf[i, y + 1:])).unsqueeze(0)
            for i, y in enumerate(labels)
        ], dim=0)
        denominator = torch.exp(numerator) + torch.sum(
            torch.exp(self.s * excl), dim=1)
        L = numerator - torch.log(denominator)
        return -torch.mean(L)


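# Illustrative usage sketch (batch of 5 random 64-d embeddings, 10 classes;
# the sizes are arbitrary assumptions, not values from this repository).
def _demo_angular_loss():
    criterion = AngularPenaltySMLoss(64, 10, loss_type='sphereface')
    x = torch.randn(5, 64)
    labels = torch.randint(0, 10, (5,))
    loss = criterion(x, labels)
    loss.backward()  # gradients flow into criterion.fc
    print(float(loss))

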
def flat_data(data, labels, device, n_labels=10, add_label=False):
    """Flatten each sample to a vector, optionally appending a one-hot
    label encoding."""
    bs = data.shape[0]
    if add_label:
        gen_one_hots = pt.zeros(bs, n_labels, device=device)
        gen_one_hots.scatter_(1, labels[:, None], 1)
        labels = gen_one_hots
        return pt.cat([pt.reshape(data, (bs, -1)), labels], dim=1)
    else:
        if len(data.shape) > 2:
            return pt.reshape(data, (bs, -1))
        else:
            return data


rff_param_tuple = namedtuple('rff_params', ['w', 'b'])


def rff_sphere(x, rff_params):
    """Random Fourier features with paired cos/sin components, so each
    feature vector lies on a sphere of fixed norm."""
    w = rff_params.w
    xwt = pt.mm(x, w.t())
    z_1 = pt.cos(xwt)
    z_2 = pt.sin(xwt)
    z_cat = pt.cat((z_1, z_2), 1)
    norm_const = pt.sqrt(pt.tensor(w.shape[0]).to(pt.float32))
    z = z_cat / norm_const  # w.shape[0] == n_features / 2
    return z


def weights_sphere(d_rff, d_enc, sig, device, seed=1234):
    """Sample the frequency matrix for `rff_sphere` (no bias term)."""
    np.random.seed(seed)
    freq = np.random.randn(d_rff // 2, d_enc) / np.sqrt(sig)
    w_freq = pt.tensor(freq).to(pt.float32).to(device)
    return rff_param_tuple(w=w_freq, b=None)


def rff_rahimi_recht(x, rff_params):
    """Random Fourier features in the Rahimi-Recht form, cos(wx + b)."""
    w = rff_params.w
    b = rff_params.b
    xwt = pt.mm(x, w.t()) + b
    z = pt.cos(xwt)
    z = z * pt.sqrt(pt.tensor(2. / w.shape[0]).to(pt.float32))
    return z


def weights_rahimi_recht(d_rff, d_enc, sig, device, seed=1234):
    """Sample frequencies and phase offsets for `rff_rahimi_recht`."""
    np.random.seed(seed)
    w_freq = pt.tensor(np.random.randn(d_rff, d_enc) / np.sqrt(sig)).to(
        pt.float32).to(device)
    # Cast to float32 so the bias matches the dtype of the frequencies.
    b_freq = pt.tensor(np.random.rand(d_rff) * (2 * np.pi * sig)).to(
        pt.float32).to(device)
    return rff_param_tuple(w=w_freq, b=b_freq)


def data_label_embedding(data,
                         labels,
                         rff_params,
                         mmd_type,
                         labels_to_one_hot=False,
                         n_labels=None,
                         device=None,
                         reduce='mean'):
    """Compute the label-conditional random-feature embedding: the outer
    product of each sample's feature vector with its one-hot label,
    reduced over the batch."""
    assert reduce in {'mean', 'sum'}
    if labels_to_one_hot:
        batch_size = data.shape[0]
        one_hots = pt.zeros(batch_size, n_labels, device=device)
        one_hots.scatter_(1, labels[:, None], 1)
        labels = one_hots
    data_embedding = rff_sphere(data, rff_params) \
        if mmd_type == 'sphere' else rff_rahimi_recht(data, rff_params)
    # (batch, d_rff) x (batch, n_labels) -> (batch, d_rff, n_labels)
    embedding = pt.einsum('ki,kj->kij', [data_embedding, labels])
    return pt.mean(embedding, 0) if reduce == 'mean' else pt.sum(embedding, 0)


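# Illustrative sketch of the random-feature mean embedding (arbitrary
# assumed sizes: 16 samples, 20-dim data, 3 labels, 50 random features).
def _demo_rff_embedding():
    device = 'cpu'
    data = pt.randn(16, 20)
    labels = pt.randint(0, 3, (16,))
    params = weights_sphere(d_rff=50, d_enc=20, sig=1.0, device=device)
    emb = data_label_embedding(data, labels, params, mmd_type='sphere',
                               labels_to_one_hot=True, n_labels=3,
                               device=device)
    print(emb.shape)  # (d_rff, n_labels) == (50, 3)

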
def noisy_dataset_embedding(train_loader,
                            d_enc,
                            sig,
                            d_rff,
                            device,
                            n_labels,
                            noise_factor,
                            mmd_type,
                            sum_frequency=25,
                            graph=False):
    """Aggregate per-batch embeddings over a dataset and add Gaussian
    noise scaled by `noise_factor` to the mean embedding."""
    emb_acc = []
    n_data = 0
    if mmd_type == 'sphere':
        w_freq = weights_sphere(d_rff, d_enc, sig, device, seed=1234)
    else:
        w_freq = weights_rahimi_recht(d_rff, d_enc, sig, device, seed=1234)
    if graph:
        for data in train_loader:
            data, labels = data.x.to(device), data.y.to(device).reshape(-1)
            # Graph batches may change feature width, so the random
            # features are re-drawn (with a fixed seed) per batch.
            d_enc = data.shape[-1]
            if mmd_type == 'sphere':
                w_freq = weights_sphere(d_rff, d_enc, sig, device, seed=1234)
            else:
                w_freq = weights_rahimi_recht(d_rff,
                                              d_enc,
                                              sig,
                                              device,
                                              seed=1234)
            data = flat_data(data,
                             labels,
                             device,
                             n_labels=n_labels,
                             add_label=False)
            emb_acc.append(
                data_label_embedding(data,
                                     labels,
                                     w_freq,
                                     mmd_type,
                                     labels_to_one_hot=True,
                                     n_labels=n_labels,
                                     device=device,
                                     reduce='sum'))
            n_data += data.shape[0]
            # Periodically collapse the accumulator to bound memory use.
            if len(emb_acc) > sum_frequency:
                emb_acc = [pt.sum(pt.stack(emb_acc), 0)]
    else:
        for data, labels in train_loader:
            data, labels = data.to(device), labels.to(device)
            data = flat_data(data,
                             labels,
                             device,
                             n_labels=n_labels,
                             add_label=False)
            emb_acc.append(
                data_label_embedding(data,
                                     labels,
                                     w_freq,
                                     mmd_type,
                                     labels_to_one_hot=True,
                                     n_labels=n_labels,
                                     device=device,
                                     reduce='sum'))
            n_data += data.shape[0]
            if len(emb_acc) > sum_frequency:
                emb_acc = [pt.sum(pt.stack(emb_acc), 0)]
    emb_acc = pt.sum(pt.stack(emb_acc), 0) / n_data
    noise = pt.randn(d_rff, n_labels,
                     device=device) * (2 * noise_factor / n_data)
    noisy_emb = emb_acc + noise
    return noisy_emb


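# Illustrative sketch of embedding a whole dataset (assumed toy data:
# 64 single-channel 8x8 "images", 10 labels; the noise scale is arbitrary).
def _demo_noisy_embedding():
    from torch.utils.data import DataLoader, TensorDataset
    ds = TensorDataset(pt.randn(64, 1, 8, 8), pt.randint(0, 10, (64,)))
    loader = DataLoader(ds, batch_size=16)
    emb = noisy_dataset_embedding(loader, d_enc=64, sig=1.0, d_rff=100,
                                  device='cpu', n_labels=10,
                                  noise_factor=1.0, mmd_type='sphere')
    print(emb.shape)  # (d_rff, n_labels) == (100, 10)

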
def merge_dict(dict1, dict2):
    """Merge `dict2` into `dict1` for result history: scalar values are
    accumulated into lists, nested dicts are merged recursively."""
    for key, value in dict2.items():
        if key not in dict1:
            if isinstance(value, dict):
                dict1[key] = merge_dict({}, value)
            else:
                dict1[key] = [value]
        else:
            if isinstance(value, dict):
                merge_dict(dict1[key], value)
            else:
                dict1[key].append(value)
    return dict1


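# Illustrative sketch: accumulating two rounds of (nested) results.
def _demo_merge_dict():
    history = {}
    merge_dict(history, {'train': {'loss': 0.9}, 'round': 1})
    merge_dict(history, {'train': {'loss': 0.7}, 'round': 2})
    print(history)  # {'train': {'loss': [0.9, 0.7]}, 'round': [1, 2]}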