# File: FS-TFP/federatedscope/nlp/hetero_tasks/aggregator/aggregator.py
#
# (Listing metadata from the original paste: 287 lines, 12 KiB, Python.)

import os
import re
import copy
import random
import torch
import numpy as np
import logging
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_distances
from federatedscope.core.aggregators import ClientsAvgAggregator
from federatedscope.core.configs.config import global_cfg
# Module-level logger used by ATCAggregator below.
logger = logging.getLogger(__name__)
class ATCAggregator(ClientsAvgAggregator):
    """Aggregator that keeps one server-side model per client and mixes
    client updates according to the similarity of their gradients.

    Clients are related through the cosine distance between their flattened
    ``model_grads``:

    * ``task == 'pretrain'``: agglomerative clustering into
      ``num_agg_groups`` groups (stored in ``client_id2group``);
    * otherwise: each client ``i`` is linked to its ``num_agg_topk[i]``
      nearest clients (``client_id2topk``); the full distance ranking of
      every client is kept in ``client_id2all``.

    When parameters are aggregated, client ``j`` contributes to model ``i``
    with weight ``sample_size_j * inside_weight`` if the two are linked and
    ``sample_size_j * outside_weight`` otherwise, normalised per row.
    """

    def __init__(self, model=None, config=None, device='cpu'):
        super().__init__(model=model, config=config, device=device)
        self.client_num = config.federate.client_num
        self.task = config.model.task
        self.pretrain_tasks = config.model.pretrain_tasks
        self.num_agg_groups = config.aggregator.num_agg_groups
        self.num_agg_topk = config.aggregator.num_agg_topk
        self.inside_weight = config.aggregator.inside_weight
        self.outside_weight = config.aggregator.outside_weight
        # One server-side model per client, provided via ``update_models``.
        self.models = []
        self.neighbors = {}
        # Grouping state written by ``_compute_client_groups``.
        self.client_id2group = [None for _ in range(self.client_num)]
        self.client_id2topk = [[] for _ in range(self.client_num)]
        self.client_id2all = [[] for _ in range(self.client_num)]
        self.use_contrastive_loss = config.model.use_contrastive_loss
        # Always define the attribute so it can be read safely; it is only
        # populated/used when ``use_contrastive_loss`` is enabled.
        self.contrast_monitor = None

    def update_models(self, models):
        """Replace the list of per-client server models."""
        self.models = models

    def update_neighbors(self, neighbors):
        """Replace the neighbour mapping (its sorted keys name checkpoints)."""
        self.neighbors = neighbors

    def update_contrast_monitor(self, contrast_monitor):
        """Set the monitor object used for contrastive aggregation."""
        self.contrast_monitor = contrast_monitor

    def aggregate(self, agg_info):
        """Aggregate the client feedback contained in ``agg_info``.

        Args:
            agg_info: dict with ``'client_feedback'`` (list of per-client
                message dicts) and optionally ``'recover_fun'``.

        Returns:
            tuple: ``(model_params, tasks)`` from ``_para_weighted_avg``.
        """
        models = agg_info["client_feedback"]
        # Use the runtime config (``self.cfg``), as the rest of this class
        # does, rather than the library-default ``global_cfg``.
        recover_fun = agg_info['recover_fun'] if (
            'recover_fun' in agg_info and self.cfg.federate.use_ss) \
            else None
        avg_models, tasks = self._para_weighted_avg(models,
                                                    recover_fun=recover_fun)
        return avg_models, tasks

    def update(self, model_parameters):
        """Load the i-th state dict into ``self.models[i]`` (non-strict)."""
        for i, param in enumerate(model_parameters):
            self.models[i].load_state_dict(param, strict=False)

    def save_model(self, path, cur_round=-1):
        """Save every server model under ``<path>/global``.

        The i-th model is written to ``global_model_<id>.pt`` where ``<id>``
        is the i-th smallest key of ``self.neighbors``. Each file holds
        ``{'cur_round': cur_round, 'model': state_dict}``.
        """
        assert self.models is not None
        path = os.path.join(path, 'global')
        os.makedirs(path, exist_ok=True)
        neighbor_ids = sorted(self.neighbors.keys())
        for i, model in enumerate(self.models):
            ckpt = {'cur_round': cur_round, 'model': model.state_dict()}
            torch.save(
                ckpt,
                os.path.join(path,
                             'global_model_{}.pt'.format(neighbor_ids[i])))

    def load_model(self, path):
        """Restore ``self.models`` from checkpoints stored under ``path``.

        Each ``<path>/global/global_model_<id>.pt`` is loaded into the
        corresponding model; if ``<path>/client/client_model_<id>.pt`` also
        exists, its entries override the global ones.

        Returns:
            The ``cur_round`` stored in the last global checkpoint loaded,
            or ``None`` when there are no models to restore.

        Raises:
            ValueError: if an expected global checkpoint file is missing.
        """
        cur_round = None
        if not getattr(self, 'models', None):
            return cur_round
        global_dir = os.path.join(path, 'global')
        client_dir = os.path.join(path, 'client')
        # Escape the '.' and ignore files that do not follow the naming
        # scheme instead of failing on ``None.groups()``.
        id_pattern = re.compile(r'model_(\d+)\.pt')
        neighbor_ids = sorted(
            int(match.group(1))
            for match in map(id_pattern.search, os.listdir(global_dir))
            if match is not None)
        assert len(neighbor_ids) == len(self.models)
        for i, model in enumerate(self.models):
            cur_global_path = os.path.join(
                global_dir, 'global_model_{}.pt'.format(neighbor_ids[i]))
            cur_client_path = os.path.join(
                client_dir, 'client_model_{}.pt'.format(neighbor_ids[i]))
            if not os.path.exists(cur_global_path):
                raise ValueError(
                    "The file {} does NOT exist".format(cur_global_path))
            model_ckpt = model.state_dict()
            logger.info('Loading model from \'{}\''.format(cur_global_path))
            global_ckpt = torch.load(cur_global_path,
                                     map_location=self.device)
            model_ckpt.update(global_ckpt['model'])
            if os.path.exists(cur_client_path):
                # Client-side entries take precedence over the global ones.
                logger.info(
                    'Updating model from \'{}\''.format(cur_client_path))
                client_ckpt = torch.load(cur_client_path,
                                         map_location=self.device)
                model_ckpt.update(client_ckpt['model'])
            model.load_state_dict(model_ckpt)
            cur_round = global_ckpt['cur_round']
        return cur_round

    def _make_clustering(self):
        """Build a cosine / average-linkage AgglomerativeClustering.

        scikit-learn 1.2 renamed ``affinity`` to ``metric`` and 1.4 removed
        ``affinity``; try the new keyword first, fall back for old releases.
        """
        kwargs = dict(n_clusters=self.num_agg_groups, linkage='average')
        try:
            return AgglomerativeClustering(metric='cosine', **kwargs)
        except TypeError:
            return AgglomerativeClustering(affinity='cosine', **kwargs)

    def _compute_client_groups(self, models):
        """Relate clients by the cosine distance of their gradients.

        Updates ``self.client_id2group`` (pretrain) or
        ``self.client_id2topk`` and ``self.client_id2all`` (other tasks).

        Returns:
            list: one task per client. For pretraining a single task is
            drawn at random from ``self.pretrain_tasks`` and shared by all
            clients; otherwise every entry is ``None``.
        """
        # One row per client: all gradient tensors flattened and
        # concatenated, handed to scikit-learn as a detached CPU array.
        # ``reshape`` (unlike ``view``) also handles non-contiguous grads.
        grads = torch.stack([
            torch.cat([g.reshape(-1) for g in model['model_grads'].values()])
            for model in models
        ]).detach().cpu().numpy()
        if self.task == 'pretrain':
            clustering = self._make_clustering().fit(grads)
            self.client_id2group = clustering.labels_
            task = random.choice(self.pretrain_tasks)
            return [task for _ in range(self.client_num)]
        distances = cosine_distances(grads, grads)
        # Row i lists client ids ordered from nearest to farthest from i.
        ranking = np.argsort(distances, axis=-1)
        self.client_id2topk = [
            row[:k].tolist() for row, k in zip(ranking, self.num_agg_topk)
        ]
        self.client_id2all = ranking.tolist()
        return [None for _ in range(self.client_num)]

    def _avg_params(self, models, client_adj_norm):
        """Apply the weighted sum of client gradients to each server model.

        For server model ``i`` and every parameter ``k`` named in the client
        gradients, the result is
        ``float(param_i[k]) + sum_j client_adj_norm[i][j] * float(grad_j[k])``.

        Args:
            models: per-client feedback dicts carrying ``'model_grads'``.
            client_adj_norm: row-normalised mixing weights indexed ``[i][j]``.

        Returns:
            list[dict]: one ``{param_name: tensor}`` per server model; the
            tensors are new objects and never alias the live parameters.
        """
        grad_keys = models[0]['model_grads']
        model_grads = [model['model_grads'] for model in models]
        server_params = [{
            n: p
            for n, p in model.state_dict().items() if n in grad_keys
        } for model in self.models]
        if not server_params:
            return []
        agg_keys = set(server_params[0])
        num_models = len(server_params)
        avg_model = []
        for i, params in enumerate(server_params):
            new_params = {}
            for k, p in params.items():
                if k not in agg_keys:
                    new_params[k] = p.clone()
                    continue
                delta = None
                for j in range(num_models):
                    contrib = model_grads[j][k].float() * client_adj_norm[i][j]
                    delta = contrib if delta is None else delta + contrib
                new_params[k] = p.float() + delta
            avg_model.append(new_params)
        return avg_model

    def _build_client_adj_norm(self, models, is_linked):
        """Return the row-normalised ``client_num x client_num`` weights.

        Entry ``[i][j]`` is ``models[j]['sample_size']`` times
        ``inside_weight`` when ``is_linked(i, j)`` is true and times
        ``outside_weight`` otherwise; each row is divided by its sum.
        """
        client_adj = torch.zeros(self.client_num, self.client_num)
        for i in range(self.client_num):
            for j in range(self.client_num):
                weight = self.inside_weight if is_linked(i, j) \
                    else self.outside_weight
                client_adj[i][j] = models[j]['sample_size'] * weight
        return client_adj / client_adj.sum(dim=-1, keepdim=True)

    def _aggregate_topk(self, models):
        """Return ``{'model_para': ...}`` weighted by top-k neighbourhoods."""
        client_adj_norm = self._build_client_adj_norm(
            models, lambda i, j: j in self.client_id2topk[i])
        return {'model_para': self._avg_params(models, client_adj_norm)}

    @staticmethod
    def _to_one_based(rows):
        """Map ``[[ids...], ...]`` to ``{row+1: [id+1, ...]}`` for logs."""
        return {k + 1: [x + 1 for x in v] for k, v in enumerate(rows)}

    def _share_decoder_outputs(self, models):
        """Publish client decoder outputs and 1-based neighbour ids to
        ``self.contrast_monitor``."""
        dec_hidden = {
            k + 1: model['contrast_monitor'].dec_hidden
            for k, model in enumerate(models)
        }
        dec_out = {
            k + 1: model['contrast_monitor'].dec_out
            for k, model in enumerate(models)
        }
        all_group_ids = self._to_one_based(self.client_id2all)
        topk_group_ids = self._to_one_based(self.client_id2topk)
        self.contrast_monitor.update_dec_hidden(dec_hidden)
        self.contrast_monitor.update_dec_out(dec_out)
        self.contrast_monitor.update_all_group_ids(all_group_ids)
        self.contrast_monitor.update_topk_group_ids(topk_group_ids)

    def _para_weighted_avg(self, models, recover_fun=None):
        """Aggregate client feedback into per-client parameters.

        Args:
            models: list of per-client feedback dicts (keys read here:
                ``'model_para'``, ``'model_grads'``, ``'sample_size'``,
                ``'contrast_monitor'``).
            recover_fun: accepted for interface compatibility; unused.

        Returns:
            tuple: ``(model_params, tasks)``. ``model_params`` is a dict with
            ``'model_para'`` (plus ``'contrast_monitor'`` for contrastive
            pretraining), or ``None`` in contrastive rounds whose stat is not
            3. ``tasks`` holds one entry per client.
        """
        tasks = [None for _ in range(self.client_num)]
        if self.cfg.federate.method in ['local', 'global']:
            model_params = {
                'model_para': [model['model_para'] for model in models]
            }
            return model_params, tasks

        if self.task == 'pretrain':
            # Cluster clients; same-group peers get ``inside_weight``.
            tasks = self._compute_client_groups(models)
            groups = self.client_id2group
            group_id2client = {
                gid: [cid for cid in range(self.client_num)
                      if groups[cid] == gid]
                for gid in range(self.num_agg_groups)
            }
            logger.info('group_id2client: {}'.format({
                k + 1: [x + 1 for x in v]
                for k, v in group_id2client.items()
            }))
            client_adj_norm = self._build_client_adj_norm(
                models, lambda i, j: groups[i] == groups[j])
            model_params = {
                'model_para': self._avg_params(models, client_adj_norm)
            }
            if self.use_contrastive_loss:
                model_params['contrast_monitor'] = self.contrast_monitor
            return model_params, tasks

        if not self.use_contrastive_loss:
            # Rank clients by gradient distance; top-k peers weigh more.
            tasks = self._compute_client_groups(models)
            logger.info('client_id2topk: {}'.format(
                self._to_one_based(self.client_id2topk)))
            return self._aggregate_topk(models), tasks

        # Contrastive fine-tuning: all clients must be in the same stage.
        contrast_stat = models[0]['contrast_monitor'].stat
        for model in models:
            assert model['contrast_monitor'].stat == contrast_stat
        self.contrast_monitor.update_stat(contrast_stat)
        model_params = None
        if contrast_stat == 2:
            self._share_decoder_outputs(models)
        elif contrast_stat == 3:
            tasks = self._compute_client_groups(models)
            logger.info('client_id2all (n_topk={}): {}'.format(
                self.num_agg_topk, self._to_one_based(self.client_id2all)))
            model_params = self._aggregate_topk(models)
        return model_params, tasks