import torch.nn.functional as F
from torch import nn
from torch.nn.utils import spectral_norm
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
import numpy as np
import torch as pt
from collections import namedtuple


class EncNet(nn.Module):
    """Two-layer spectrally-normalized MLP (no biases, ReLU in between)
    mapping a client encoding to `num_params` outputs."""

    def __init__(self, in_channel, num_params, hid_dim=64):
        super(EncNet, self).__init__()
        self.num_params = num_params
        self.fc_layer = nn.Sequential(
            spectral_norm(nn.Linear(in_channel, hid_dim, bias=False)),
            nn.ReLU(inplace=True),
            spectral_norm(nn.Linear(hid_dim, num_params, bias=False)),
        )

    def forward(self, client_enc):
        """Return the encoded update for `client_enc` (shape (..., num_params))."""
        return self.fc_layer(client_enc)


class PolicyNet(nn.Module):
    """Three-layer spectrally-normalized MLP used to transform the running
    policy mean (biases enabled, ReLU activations)."""

    def __init__(self, in_channel, num_params, hid_dim=32):
        super(PolicyNet, self).__init__()
        self.num_params = num_params
        self.fc_layer = nn.Sequential(
            spectral_norm(nn.Linear(in_channel, hid_dim)),
            nn.ReLU(inplace=True),
            spectral_norm(nn.Linear(hid_dim, hid_dim)),
            nn.ReLU(inplace=True),
            spectral_norm(nn.Linear(hid_dim, num_params)),
        )

    def forward(self, client_enc):
        """Return the proposed mean update (shape (..., num_params))."""
        return self.fc_layer(client_enc)


class DisHyperNet(nn.Module):
    """Hypernetwork for discrete hyperparameters: one softmax head per
    candidate set in `cands`, driven by learnable client encodings."""

    def __init__(self, encoding, cands, n_clients, device):
        # `n_clients` and `device` are kept for interface compatibility;
        # they are not used inside this module.
        super(DisHyperNet, self).__init__()
        self.dim = input_dim = encoding.shape[1]
        self.encoding = torch.nn.Parameter(encoding, requires_grad=True)
        self.EncNet = EncNet(input_dim, 64)
        loss_type = 'sphereface'
        # NOTE(review): 10 classes hard-coded for the angular-margin
        # regularizer -- confirm this matches the training setup.
        self.enc_loss = AngularPenaltySMLoss(64, 10, loss_type)
        # Regularization weight per margin-loss variant.
        self.reg_alpha = {'sphereface': 0.1, 'cosface': 0.2,
                          'arcface': 0.1}[loss_type]
        self.out = nn.ModuleList()
        for _, cand_values in cands.items():
            # dim=-1 made explicit: implicit-dim nn.Softmax is deprecated
            # (same result for the 2-D activations produced here).
            self.out.append(
                nn.Sequential(nn.Linear(64, len(cand_values), bias=False),
                              nn.Softmax(dim=-1)))

    def forward(self):
        """Return ([per-head softmax probabilities], encoding regularizer)."""
        client_enc = self.EncNet(self.encoding)
        client_enc_reg = 0  # regularizer currently disabled
        logits = [head(client_enc) for head in self.out]
        return logits, client_enc_reg


class HyperNet(nn.Module):
    """Hypernetwork for continuous hyperparameters in [0, 1]: samples from a
    Gaussian policy whose mean is combined from the client encoding and a
    recurrently-updated running mean."""

    def __init__(self, encoding, num_params, n_clients, device, var):
        super(HyperNet, self).__init__()
        self.dim = input_dim = encoding.shape[1]
        self.var = var  # fixed isotropic variance of the sampling policy
        self.encoding = torch.nn.Parameter(encoding, requires_grad=True)
        # Running policy mean, initialized at the center of [0, 1].
        self.mean = torch.full((n_clients, num_params), 0.5).to(device)
        self.EncNet = EncNet(input_dim, num_params)
        self.meanNet = PolicyNet(num_params, num_params)
        self.combine = nn.Sequential(nn.Linear(num_params * 2, num_params),
                                     nn.Sigmoid())
        self.alpha = 0.8

    def forward(self):
        """Sample hyperparameters; return (sample, log_prob, entropy)."""
        client_enc = self.EncNet(self.encoding)
        mean_update = self.meanNet(self.mean)
        mean = self.combine(torch.cat([client_enc, mean_update], dim=-1))
        cov_matrix = torch.eye(mean.shape[-1]).to(mean.device) * self.var
        dist = MultivariateNormal(loc=mean, covariance_matrix=cov_matrix)
        # Clamp keeps samples in the unit hypercube; log-prob/entropy are
        # evaluated for the unclipped Gaussian at the clipped sample.
        sample = torch.clamp(dist.sample(), 0., 1.)
        logprob = dist.log_prob(sample)
        entropy = dist.entropy()
        # Persist the new mean for the next forward pass (no grad through it).
        self.mean.data.copy_(mean.data)
        return sample, logprob, entropy


def parse_pbounds(search_space):
    """Convert a search space of hypers (objects with .lower/.upper/.log)
    into a dict name -> (low, high); log-scaled hypers map to log10 space.

    Raises:
        ValueError: if an entry lacks `lower`/`upper` attributes.
    """
    pbounds = {}
    for name, hyper in search_space.items():
        if not (hasattr(hyper, 'lower') and hasattr(hyper, 'upper')):
            raise ValueError("Unsupported hyper type {}".format(type(hyper)))
        if hyper.log:
            lo, hi = np.log10(hyper.lower), np.log10(hyper.upper)
        else:
            lo, hi = hyper.lower, hyper.upper
        pbounds[name] = (lo, hi)
    return pbounds


def map_value_to_param(x, pbounds, ss):
    """Map unit-interval values `x` onto concrete hyperparameter values.

    Each x[i] is linearly rescaled into its (low, high) bound; log-scaled
    hypers are exponentiated back, and integer hypers (detected from the
    type name) are truncated to int.
    """
    x = np.array(x).reshape(-1)
    assert len(x) == len(pbounds)
    params = {}
    for i, (name, (lo, hi)) in enumerate(pbounds.items()):
        p_inst = ss[name]
        p = float(1. * x[i] * (hi - lo) + lo)
        if p_inst.log:
            p = 10 ** p
        # Heuristic: a type name containing "int" marks an integer hyper.
        params[name] = int(p) if 'int' in str(type(p_inst)).lower() else p
    return params


class AngularPenaltySMLoss(nn.Module):
    """Angular-penalty softmax losses: ArcFace, SphereFace, CosFace.

    References:
        ArcFace:    https://arxiv.org/abs/1801.07698
        SphereFace: https://arxiv.org/abs/1704.08063
        CosFace:    https://arxiv.org/abs/1801.09414
    """

    def __init__(self, in_features, out_features, loss_type='arcface',
                 eps=1e-7, s=None, m=None):
        super(AngularPenaltySMLoss, self).__init__()
        loss_type = loss_type.lower()
        assert loss_type in ['arcface', 'sphereface', 'cosface']
        # Default scale s and margin m per variant (truthiness check kept
        # from the original: passing 0 falls back to the default).
        if loss_type == 'arcface':
            self.s = 64.0 if not s else s
            self.m = 0.5 if not m else m
        elif loss_type == 'sphereface':
            self.s = 64.0 if not s else s
            self.m = 1.35 if not m else m
        else:  # cosface
            self.s = 30.0 if not s else s
            self.m = 0.4 if not m else m
        self.loss_type = loss_type
        self.in_features = in_features
        self.out_features = out_features
        self.fc = nn.Linear(in_features, out_features, bias=False)
        self.eps = eps

    def forward(self, x, labels):
        """Return the scalar margin loss for embeddings `x` and int `labels`."""
        assert len(x) == len(labels)
        assert torch.min(labels) >= 0
        assert torch.max(labels) < self.out_features
        # BUG FIX: the original `for W in self.fc.parameters(): W =
        # F.normalize(W, ...)` only rebound a loop variable and never
        # modified the parameter. Normalize in place so the fc rows are
        # genuine unit-norm class centers (gradients still flow via fc(x)).
        with torch.no_grad():
            self.fc.weight.copy_(F.normalize(self.fc.weight, p=2, dim=1))
        x = F.normalize(x, p=2, dim=1)
        wf = self.fc(x)  # cosine similarities, shape (batch, out_features)
        # Cosine of each sample with its target class.
        target_cos = torch.diagonal(wf.transpose(0, 1)[labels])
        if self.loss_type == 'cosface':
            numerator = self.s * (target_cos - self.m)
        elif self.loss_type == 'arcface':
            numerator = self.s * torch.cos(
                torch.acos(torch.clamp(target_cos, -1. + self.eps,
                                       1 - self.eps)) + self.m)
        else:  # sphereface
            numerator = self.s * torch.cos(self.m * torch.acos(
                torch.clamp(target_cos, -1. + self.eps, 1 - self.eps)))
        # Non-target logits: row i with column labels[i] removed.
        excl = torch.cat([
            torch.cat((wf[i, :y], wf[i, y + 1:])).unsqueeze(0)
            for i, y in enumerate(labels)
        ], dim=0)
        denominator = torch.exp(numerator) + torch.sum(
            torch.exp(self.s * excl), dim=1)
        L = numerator - torch.log(denominator)
        return -torch.mean(L)


def flat_data(data, labels, device, n_labels=10, add_label=False):
    """Flatten `data` to (batch, -1); optionally append one-hot labels."""
    bs = data.shape[0]
    if add_label:
        gen_one_hots = pt.zeros(bs, n_labels, device=device)
        gen_one_hots.scatter_(1, labels[:, None], 1)
        return pt.cat([pt.reshape(data, (bs, -1)), gen_one_hots], dim=1)
    if len(data.shape) > 2:
        return pt.reshape(data, (bs, -1))
    return data


# (w, b): random-feature frequencies and (optional) phase offsets.
rff_param_tuple = namedtuple('rff_params', ['w', 'b'])


def rff_sphere(x, rff_params):
    """Spherical random Fourier features: [cos(xW^T), sin(xW^T)] / sqrt(D/2).

    w.shape[0] == n_features / 2, hence the sqrt(w.shape[0]) normalizer.
    """
    w = rff_params.w
    xwt = pt.mm(x, w.t())
    z_cat = pt.cat((pt.cos(xwt), pt.sin(xwt)), 1)
    norm_const = pt.sqrt(pt.tensor(w.shape[0]).to(pt.float32))
    return z_cat / norm_const


def weights_sphere(d_rff, d_enc, sig, device, seed=1234):
    """Draw fixed Gaussian frequencies (d_rff//2, d_enc) with std 1/sqrt(sig)."""
    np.random.seed(seed)
    freq = np.random.randn(d_rff // 2, d_enc) / np.sqrt(sig)
    w_freq = pt.tensor(freq).to(pt.float32).to(device)
    return rff_param_tuple(w=w_freq, b=None)


def rff_rahimi_recht(x, rff_params):
    """Rahimi-Recht random features: sqrt(2/D) * cos(xW^T + b)."""
    w, b = rff_params.w, rff_params.b
    xwt = pt.mm(x, w.t()) + b
    return pt.cos(xwt) * pt.sqrt(pt.tensor(2. / w.shape[0]).to(pt.float32))


def weights_rahimi_recht(d_rff, d_enc, sig, device, seed=1234):
    """Draw fixed frequencies and phase offsets for Rahimi-Recht features."""
    np.random.seed(seed)
    w_freq = pt.tensor(np.random.randn(d_rff, d_enc) / np.sqrt(sig)).to(
        pt.float32).to(device)
    # BUG FIX: cast b to float32. The original float64 offset promoted the
    # features to float64, which breaks the float32 einsum downstream.
    # NOTE(review): the offset range includes `sig` (0, 2*pi*sig) rather
    # than the textbook Uniform[0, 2*pi] -- confirm this is intentional.
    b_freq = pt.tensor(np.random.rand(d_rff) * (2 * np.pi * sig)).to(
        pt.float32).to(device)
    return rff_param_tuple(w=w_freq, b=b_freq)


def data_label_embedding(data, labels, rff_params, mmd_type,
                         labels_to_one_hot=False, n_labels=None, device=None,
                         reduce='mean'):
    """Outer product of random-feature data embeddings with (one-hot) labels,
    reduced over the batch by mean or sum -> (d_rff, n_labels)."""
    assert reduce in {'mean', 'sum'}
    if labels_to_one_hot:
        batch_size = data.shape[0]
        one_hots = pt.zeros(batch_size, n_labels, device=device)
        one_hots.scatter_(1, labels[:, None], 1)
        labels = one_hots
    data_embedding = rff_sphere(data, rff_params) \
        if mmd_type == 'sphere' else rff_rahimi_recht(data, rff_params)
    embedding = pt.einsum('ki,kj->kij', [data_embedding, labels])
    return pt.mean(embedding, 0) if reduce == 'mean' else pt.sum(embedding, 0)


def _accumulate_embedding(emb_acc, data, labels, w_freq, mmd_type, n_labels,
                          device, sum_frequency):
    """Append one batch's summed embedding to `emb_acc`, compacting the list
    periodically so memory stays bounded."""
    emb_acc.append(
        data_label_embedding(data, labels, w_freq, mmd_type,
                             labels_to_one_hot=True, n_labels=n_labels,
                             device=device, reduce='sum'))
    if len(emb_acc) > sum_frequency:
        total = pt.sum(pt.stack(emb_acc), 0)
        emb_acc.clear()
        emb_acc.append(total)


def noisy_dataset_embedding(train_loader, d_enc, sig, d_rff, device, n_labels,
                            noise_factor, mmd_type, sum_frequency=25,
                            graph=False):
    """Mean data/label RFF embedding of a whole dataset plus Gaussian noise
    scaled by 2*noise_factor/n_data (DP-style release)."""
    emb_acc = []
    n_data = 0
    if mmd_type == 'sphere':
        w_freq = weights_sphere(d_rff, d_enc, sig, device, seed=1234)
    else:
        w_freq = weights_rahimi_recht(d_rff, d_enc, sig, device, seed=1234)
    if graph:
        for batch in train_loader:
            data, labels = batch.x.to(device), batch.y.to(device).reshape(-1)
            # Graph batches may change feature width, so refresh the
            # frequencies from the observed dimension.
            d_enc = data.shape[-1]
            if mmd_type == 'sphere':
                w_freq = weights_sphere(d_rff, d_enc, sig, device, seed=1234)
            else:
                w_freq = weights_rahimi_recht(d_rff, d_enc, sig, device,
                                              seed=1234)
            data = flat_data(data, labels, device, n_labels=n_labels,
                             add_label=False)
            _accumulate_embedding(emb_acc, data, labels, w_freq, mmd_type,
                                  n_labels, device, sum_frequency)
            n_data += data.shape[0]
    else:
        for data, labels in train_loader:
            data, labels = data.to(device), labels.to(device)
            data = flat_data(data, labels, device, n_labels=n_labels,
                             add_label=False)
            _accumulate_embedding(emb_acc, data, labels, w_freq, mmd_type,
                                  n_labels, device, sum_frequency)
            n_data += data.shape[0]
    emb_acc = pt.sum(pt.stack(emb_acc), 0) / n_data
    noise = pt.randn(d_rff, n_labels, device=device) * (2 * noise_factor
                                                        / n_data)
    return emb_acc + noise


def merge_dict(dict1, dict2):
    """Recursively merge `dict2` into `dict1` for history tracking: leaf
    values accumulate into lists, nested dicts merge in place.

    Returns:
        dict1, mutated in place.
    """
    for key, value in dict2.items():
        if key not in dict1:
            dict1[key] = merge_dict({}, value) if isinstance(value, dict) \
                else [value]
        elif isinstance(value, dict):
            merge_dict(dict1[key], value)
        else:
            dict1[key].append(value)
    return dict1