FS-TFP/federatedscope/gfl/gcflplus/worker.py
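"""Workers for GCFL+ in FederatedScope's GFL module.

Note: this module-level summary is added for readability and describes the
code below. GCFLPlusServer clusters clients by the dynamic-time-warping (DTW)
distance between their gradient-norm sequences, splits clusters via a min-cut,
and aggregates/broadcasts a model per cluster; GCFLPlusClient reports its
weight delta and gradient norm along with the trained model parameters.
"""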


import torch
import logging
import copy
import numpy as np
from federatedscope.core.message import Message
from federatedscope.core.workers.server import Server
from federatedscope.core.workers.client import Client
from federatedscope.core.auxiliaries.utils import merge_dict_of_results
from federatedscope.gfl.gcflplus.utils import compute_pairwise_distances, \
    min_cut, norm

logger = logging.getLogger(__name__)


class GCFLPlusServer(Server):
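    """GCFL+ server.

    Starts with all clients in a single cluster and refines the clustering
    over rounds: when a cluster's mean update norm falls below
    gcflplus.EPS_1 while its maximum per-client update norm exceeds
    gcflplus.EPS_2, the cluster is bipartitioned by a min-cut over pairwise
    DTW distances between the clients' recent gradient-norm sequences.
    Aggregation and model broadcasting are then performed per cluster.
    """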
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 **kwargs):
        super(GCFLPlusServer,
              self).__init__(ID, state, config, data, model, client_num,
                             total_round_num, device, strategy, **kwargs)
        # Initial cluster
        self.cluster_indices = [
            np.arange(1, self._cfg.federate.client_num + 1).astype("int")
        ]
        self.client_clusters = [[ID for ID in cluster_id]
                                for cluster_id in self.cluster_indices]
        # Maintain a grad sequence
        self.seqs_grads = {
            idx: []
            for idx in range(1, self._cfg.federate.client_num + 1)
        }
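
    # For one cluster, compute (a) the largest single-client update norm and
    # (b) the norm of the averaged update; both feed the cluster-split
    # criterion in check_and_move_on below.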
    def compute_update_norm(self, cluster):
        max_norm = -np.inf
        cluster_dWs = []
        for key in cluster:
            content = self.msg_buffer['train'][self.state][key]
            _, model_para, client_dw, _ = content
            dW = {}
            for k in model_para.keys():
                dW[k] = client_dw[k]
            update_norm = norm(dW)
            if update_norm > max_norm:
                max_norm = update_norm
            cluster_dWs.append(
                torch.cat([value.flatten() for value in dW.values()]))

        mean_norm = torch.norm(torch.mean(torch.stack(cluster_dWs),
                                          dim=0)).item()
        return max_norm, mean_norm
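
    # Overrides Server.check_and_move_on: besides the usual round
    # bookkeeping, it decides whether to split clusters and performs
    # aggregation and broadcasting separately for each cluster.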
    def check_and_move_on(self, check_eval_result=False):
        if check_eval_result:
            # all clients are participating in evaluation
            minimal_number = self.client_num
        else:
            # sampled clients are participating in training
            minimal_number = self.sample_client_num

        if self.check_buffer(self.state, minimal_number, check_eval_result):
            if not check_eval_result:  # in the training process
                # Get all the message
                train_msg_buffer = self.msg_buffer['train'][self.state]
                for model_idx in range(self.model_num):
                    model = self.models[model_idx]
                    aggregator = self.aggregators[model_idx]
                    msg_list = list()
                    for client_id in train_msg_buffer:
                        if self.model_num == 1:
                            train_data_size, model_para, _, convGradsNorm = \
                                train_msg_buffer[client_id]
                            self.seqs_grads[client_id].append(convGradsNorm)
                            msg_list.append((train_data_size, model_para))
                        else:
                            raise ValueError(
                                'GCFL server does not support multi-model.')

                    cluster_indices_new = []
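                    # A cluster is split only when the averaged update is
                    # small (mean_norm < EPS_1) but at least one client still
                    # changes a lot (max_norm > EPS_2), the cluster has more
                    # than two members, warm-up rounds have passed, and every
                    # client has a full gradient-norm sequence for the DTW
                    # distance.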
                    for cluster in self.cluster_indices:
                        max_norm, mean_norm = self.compute_update_norm(cluster)
                        # create new cluster
                        if mean_norm < self._cfg.gcflplus.EPS_1 and max_norm\
                                > self._cfg.gcflplus.EPS_2 and len(
                                cluster) > 2 and self.state > 20 and all(
                                len(value) >= self._cfg.gcflplus.seq_length
                                for value in self.seqs_grads.values()):
                            _, model_para_cluster, _, _ = self.msg_buffer[
                                'train'][self.state][cluster[0]]
                            tmp = [
                                self.seqs_grads[ID]
                                [-self._cfg.gcflplus.seq_length:]
                                for ID in cluster
                            ]
                            dtw_distances = compute_pairwise_distances(
                                tmp, self._cfg.gcflplus.standardize)
                            c1, c2 = min_cut(
                                np.max(dtw_distances) - dtw_distances, cluster)
                            cluster_indices_new += [c1, c2]
                            # reset seqs_grads for all clients
                            self.seqs_grads = {
                                idx: []
                                for idx in range(
                                    1, self._cfg.federate.client_num + 1)
                            }
                        # keep this cluster
                        else:
                            cluster_indices_new += [cluster]

                    self.cluster_indices = cluster_indices_new
                    self.client_clusters = [[
                        ID for ID in cluster_id
                    ] for cluster_id in self.cluster_indices]
                self.state += 1
                if self.state % self._cfg.eval.freq == 0 and self.state != \
                        self.total_round_num:
                    # Evaluate
                    logger.info(
                        'Server: Starting evaluation at round {:d}.'.format(
                            self.state))
                    self.eval()

                if self.state < self.total_round_num:
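                    # Aggregate within each cluster separately and broadcast
                    # the resulting model only to that cluster's members.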
                    for cluster in self.cluster_indices:
                        msg_list = list()
                        for key in cluster:
                            content = self.msg_buffer['train'][self.state -
                                                               1][key]
                            train_data_size, model_para, client_dw, \
                                convGradsNorm = content
                            msg_list.append((train_data_size, model_para))

                        agg_info = {
                            'client_feedback': msg_list,
                            'recover_fun': self.recover_fun
                        }
                        result = aggregator.aggregate(agg_info)
                        model.load_state_dict(result, strict=False)
                        # aggregator.update(result)
                        # Send to Clients
                        self.comm_manager.send(
                            Message(msg_type='model_para',
                                    sender=self.ID,
                                    receiver=cluster.tolist(),
                                    state=self.state,
                                    content=result))

                    # Move to next round of training
                    logger.info(
                        f'----------- Starting a new training round (Round '
                        f'#{self.state}) -------------')
                    # Clean the msg_buffer
                    self.msg_buffer['train'][self.state - 1].clear()
                else:
                    # Final Evaluate
                    logger.info('Server: Training is finished! Starting '
                                'evaluation.')
                    self.eval()

            else:  # in the evaluation process
                # Get all the message & aggregate
                formatted_eval_res = self.merge_eval_results_from_all_clients()
                self.history_results = merge_dict_of_results(
                    self.history_results, formatted_eval_res)
                self.check_and_save()


class GCFLPlusClient(Client):
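    """GCFL+ client.

    Each round it trains locally and, besides the sample size and updated
    model parameters, reports the per-parameter weight delta dW and the norm
    of its current gradients, which the server accumulates into per-client
    gradient-norm sequences for clustering.
    """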
    def callback_funcs_for_model_para(self, message: Message):
        round, sender, content = message.state, message.sender, message.content
        # Cache old W
        W_old = copy.deepcopy(content)
        self.trainer.update(content)
        self.state = round
        sample_size, model_para, results = self.trainer.train()
        if self._cfg.federate.share_local_model and not \
                self._cfg.federate.online_aggr:
            model_para = copy.deepcopy(model_para)
        logger.info(
            self._monitor.format_eval_res(results,
                                          rnd=self.state,
                                          role='Client #{}'.format(self.ID)))

        # Compute norm of W & norm of grad
        dW = dict()
        for key in model_para.keys():
            dW[key] = model_para[key] - W_old[key].cpu()
        self.W = {key: value for key, value in self.model.named_parameters()}

        convGradsNorm = dict()
        for key in model_para.keys():
            if key in self.W and self.W[key].grad is not None:
                convGradsNorm[key] = self.W[key].grad
        convGradsNorm = norm(convGradsNorm)

        self.comm_manager.send(
            Message(msg_type='model_para',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=(sample_size, model_para, dW, convGradsNorm)))