215 lines
9.2 KiB
Python
215 lines
9.2 KiB
Python
import torch
|
|
import logging
|
|
import copy
|
|
import numpy as np
|
|
|
|
from federatedscope.core.message import Message
|
|
from federatedscope.core.workers.server import Server
|
|
from federatedscope.core.workers.client import Client
|
|
from federatedscope.core.auxiliaries.utils import merge_dict_of_results
|
|
from federatedscope.gfl.gcflplus.utils import compute_pairwise_distances, \
|
|
min_cut, norm
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class GCFLPlusServer(Server):
|
|
def __init__(self,
|
|
ID=-1,
|
|
state=0,
|
|
config=None,
|
|
data=None,
|
|
model=None,
|
|
client_num=5,
|
|
total_round_num=10,
|
|
device='cpu',
|
|
strategy=None,
|
|
**kwargs):
|
|
super(GCFLPlusServer,
|
|
self).__init__(ID, state, config, data, model, client_num,
|
|
total_round_num, device, strategy, **kwargs)
|
|
# Initial cluster
|
|
self.cluster_indices = [
|
|
np.arange(1, self._cfg.federate.client_num + 1).astype("int")
|
|
]
|
|
self.client_clusters = [[ID for ID in cluster_id]
|
|
for cluster_id in self.cluster_indices]
|
|
# Maintain a grad sequence
|
|
self.seqs_grads = {
|
|
idx: []
|
|
for idx in range(1, self._cfg.federate.client_num + 1)
|
|
}
|
|
|
|
def compute_update_norm(self, cluster):
|
|
max_norm = -np.inf
|
|
cluster_dWs = []
|
|
for key in cluster:
|
|
content = self.msg_buffer['train'][self.state][key]
|
|
_, model_para, client_dw, _ = content
|
|
dW = {}
|
|
for k in model_para.keys():
|
|
dW[k] = client_dw[k]
|
|
update_norm = norm(dW)
|
|
if update_norm > max_norm:
|
|
max_norm = update_norm
|
|
cluster_dWs.append(
|
|
torch.cat([value.flatten() for value in dW.values()]))
|
|
mean_norm = torch.norm(torch.mean(torch.stack(cluster_dWs),
|
|
dim=0)).item()
|
|
return max_norm, mean_norm
|
|
|
|
def check_and_move_on(self, check_eval_result=False):
|
|
|
|
if check_eval_result:
|
|
# all clients are participating in evaluation
|
|
minimal_number = self.client_num
|
|
else:
|
|
# sampled clients are participating in training
|
|
minimal_number = self.sample_client_num
|
|
|
|
if self.check_buffer(self.state, minimal_number, check_eval_result):
|
|
|
|
if not check_eval_result: # in the training process
|
|
# Get all the message
|
|
train_msg_buffer = self.msg_buffer['train'][self.state]
|
|
for model_idx in range(self.model_num):
|
|
model = self.models[model_idx]
|
|
aggregator = self.aggregators[model_idx]
|
|
msg_list = list()
|
|
for client_id in train_msg_buffer:
|
|
if self.model_num == 1:
|
|
train_data_size, model_para, _, convGradsNorm = \
|
|
train_msg_buffer[client_id]
|
|
self.seqs_grads[client_id].append(convGradsNorm)
|
|
msg_list.append((train_data_size, model_para))
|
|
else:
|
|
raise ValueError(
|
|
'GCFL server not support multi-model.')
|
|
|
|
cluster_indices_new = []
|
|
for cluster in self.cluster_indices:
|
|
max_norm, mean_norm = self.compute_update_norm(cluster)
|
|
# create new cluster
|
|
if mean_norm < self._cfg.gcflplus.EPS_1 and max_norm\
|
|
> self._cfg.gcflplus.EPS_2 and len(
|
|
cluster) > 2 and self.state > 20 and all(
|
|
len(value) >= self._cfg.gcflplus.seq_length
|
|
for value in self.seqs_grads.values()):
|
|
_, model_para_cluster, _, _ = self.msg_buffer[
|
|
'train'][self.state][cluster[0]]
|
|
tmp = [
|
|
self.seqs_grads[ID]
|
|
[-self._cfg.gcflplus.seq_length:]
|
|
for ID in cluster
|
|
]
|
|
dtw_distances = compute_pairwise_distances(
|
|
tmp, self._cfg.gcflplus.standardize)
|
|
c1, c2 = min_cut(
|
|
np.max(dtw_distances) - dtw_distances, cluster)
|
|
cluster_indices_new += [c1, c2]
|
|
# reset seqs_grads for all clients
|
|
self.seqs_grads = {
|
|
idx: []
|
|
for idx in range(
|
|
1, self._cfg.federate.client_num + 1)
|
|
}
|
|
# keep this cluster
|
|
else:
|
|
cluster_indices_new += [cluster]
|
|
|
|
self.cluster_indices = cluster_indices_new
|
|
self.client_clusters = [[
|
|
ID for ID in cluster_id
|
|
] for cluster_id in self.cluster_indices]
|
|
|
|
self.state += 1
|
|
if self.state % self._cfg.eval.freq == 0 and self.state != \
|
|
self.total_round_num:
|
|
# Evaluate
|
|
logger.info(
|
|
'Server: Starting evaluation at round {:d}.'.format(
|
|
self.state))
|
|
self.eval()
|
|
|
|
if self.state < self.total_round_num:
|
|
for cluster in self.cluster_indices:
|
|
msg_list = list()
|
|
for key in cluster:
|
|
content = self.msg_buffer['train'][self.state -
|
|
1][key]
|
|
train_data_size, model_para, client_dw, \
|
|
convGradsNorm = content
|
|
msg_list.append((train_data_size, model_para))
|
|
|
|
agg_info = {
|
|
'client_feedback': msg_list,
|
|
'recover_fun': self.recover_fun
|
|
}
|
|
result = aggregator.aggregate(agg_info)
|
|
model.load_state_dict(result, strict=False)
|
|
# aggregator.update(result)
|
|
# Send to Clients
|
|
self.comm_manager.send(
|
|
Message(msg_type='model_para',
|
|
sender=self.ID,
|
|
receiver=cluster.tolist(),
|
|
state=self.state,
|
|
content=result))
|
|
|
|
# Move to next round of training
|
|
logger.info(
|
|
f'----------- Starting a new traininground(Round '
|
|
f'#{self.state}) -------------')
|
|
# Clean the msg_buffer
|
|
self.msg_buffer['train'][self.state - 1].clear()
|
|
|
|
else:
|
|
# Final Evaluate
|
|
logger.info('Server: Training is finished! Starting '
|
|
'evaluation.')
|
|
self.eval()
|
|
|
|
else: # in the evaluation process
|
|
# Get all the message & aggregate
|
|
formatted_eval_res = self.merge_eval_results_from_all_clients()
|
|
self.history_results = merge_dict_of_results(
|
|
self.history_results, formatted_eval_res)
|
|
self.check_and_save()
|
|
|
|
|
|
class GCFLPlusClient(Client):
|
|
def callback_funcs_for_model_para(self, message: Message):
|
|
round, sender, content = message.state, message.sender, message.content
|
|
# Cache old W
|
|
W_old = copy.deepcopy(content)
|
|
self.trainer.update(content)
|
|
self.state = round
|
|
sample_size, model_para, results = self.trainer.train()
|
|
if self._cfg.federate.share_local_model and not \
|
|
self._cfg.federate.online_aggr:
|
|
model_para = copy.deepcopy(model_para)
|
|
logger.info(
|
|
self._monitor.format_eval_res(results,
|
|
rnd=self.state,
|
|
role='Client #{}'.format(self.ID)))
|
|
|
|
# Compute norm of W & norm of grad
|
|
dW = dict()
|
|
for key in model_para.keys():
|
|
dW[key] = model_para[key] - W_old[key].cpu()
|
|
|
|
self.W = {key: value for key, value in self.model.named_parameters()}
|
|
|
|
convGradsNorm = dict()
|
|
for key in model_para.keys():
|
|
if key in self.W and self.W[key].grad is not None:
|
|
convGradsNorm[key] = self.W[key].grad
|
|
convGradsNorm = norm(convGradsNorm)
|
|
|
|
self.comm_manager.send(
|
|
Message(msg_type='model_para',
|
|
sender=self.ID,
|
|
receiver=[sender],
|
|
state=self.state,
|
|
content=(sample_size, model_para, dW, convGradsNorm)))
|