import torch
import logging
import copy
import numpy as np

from federatedscope.core.message import Message
from federatedscope.core.workers.server import Server
from federatedscope.core.auxiliaries.utils import merge_dict
from federatedscope.cl.fedgc.utils import global_NT_xentloss

logger = logging.getLogger(__name__)


class GlobalContrastFLServer(Server):
    r"""
    GlobalContrastFL (FedGC) server. Training consists of two parts: a
    FedAvg aggregator for the clients' model weights, and the computation
    of a global contrastive loss from all sampled clients' embeddings,
    which is broadcast back to the clients to train their local models.
    """
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 **kwargs):
        super(GlobalContrastFLServer,
              self).__init__(ID, state, config, data, model, client_num,
                             total_round_num, device, strategy, **kwargs)
        # Initialize the per-client buffers: seqs_embedding maps each
        # client ID to its reported embeddings, and loss_list maps each
        # client ID to its latest global loss
        self.seqs_embedding = {
            idx: ()
            for idx in range(1, self._cfg.federate.client_num + 1)
        }
        self.loss_list = {
            idx: 0
            for idx in range(1, self._cfg.federate.client_num + 1)
        }

    def _register_default_handlers(self):
        self.register_handlers('join_in', self.callback_funcs_for_join_in)
        self.register_handlers('join_in_info', self.callback_funcs_for_join_in)
        self.register_handlers('model_para', self.callback_funcs_model_para)
        self.register_handlers('metrics', self.callback_funcs_for_metrics)
        self.register_handlers('pred_embedding',
                               self.callback_funcs_global_loss)
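
    # Message flow of one FedGC training round, as wired up above:
    # sampled clients report the embeddings of their two views in
    # 'pred_embedding' messages; callback_funcs_global_loss buffers them
    # and, once enough have arrived, check_and_move_on_for_global_loss
    # computes a per-client global NT-Xent loss and returns it in a
    # 'global_loss' message. Model weights travel separately in
    # 'model_para' messages and are merged by the inherited FedAvg
    # aggregation via check_and_move_on.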

    def check_and_move_on_for_global_loss(self):
        minimal_number = self.sample_client_num

        if self.check_buffer(self.state,
                             minimal_number,
                             check_eval_result=False):
            # Receiving enough feedback in the training process

            # Get all the messages
            train_msg_buffer = self.msg_buffer['train'][self.state]
            for model_idx in range(self.model_num):
                model = self.models[model_idx]
                for client_id in train_msg_buffer:
                    if self.model_num == 1:
                        pred_embedding = train_msg_buffer[client_id]
                        self.seqs_embedding[client_id] = pred_embedding
                    else:
                        raise ValueError(
                            'GlobalContrastFL server does not support '
                            'multi-model.')
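
            # Pairing scheme for the global contrastive loss (a summary
            # of how global_NT_xentloss is invoked here): each client's
            # two view embeddings (z1, z2) form the positive pair, while
            # the second-view embeddings of the other sampled clients
            # (others_z2) serve as extra negatives.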
            global_loss_fn = global_NT_xentloss(device=self.device)
            for client_id in train_msg_buffer:
                z1 = self.seqs_embedding[client_id][0]
                z2 = self.seqs_embedding[client_id][1]
                others_z2 = [
                    self.seqs_embedding[other_client_id][1]
                    for other_client_id in train_msg_buffer
                    if other_client_id != client_id
                ]
                self.loss_list[client_id] = global_loss_fn(
                    z1, z2, others_z2)
                logger.info(f'client {client_id} '
                            f'global_loss: {self.loss_list[client_id]}')

            self.state += 1
            if self.state <= self.total_round_num:
                # Return each client its own global loss so that the
                # local training step can back-propagate through it
                for client_id in train_msg_buffer:
                    content = {
                        'global_loss': self.loss_list[client_id],
                    }
                    self.comm_manager.send(
                        Message(msg_type='global_loss',
                                sender=self.ID,
                                receiver=[client_id],
                                state=self.state,
                                content=content))

    def check_and_move_on(self,
                          check_eval_result=False,
                          min_received_num=None):
        """
        To check the message buffer. When enough messages have been
        received, some events (such as performing aggregation, evaluation,
        and moving to the next training round) are triggered.

        Arguments:
            check_eval_result (bool): If True, check the message buffer
                for evaluation; otherwise, check the message buffer for
                training.
        """
        if min_received_num is None:
            if self._cfg.asyn.use:
                min_received_num = self._cfg.asyn.min_received_num
            else:
                min_received_num = self._cfg.federate.sample_client_num
        assert min_received_num <= self.sample_client_num

        if check_eval_result and self._cfg.federate.mode.lower(
        ) == "standalone":
            # in evaluation stage and standalone simulation mode, we assume
            # strong synchronization that receives responses from all clients
            min_received_num = len(self.comm_manager.get_neighbors().keys())

        move_on_flag = True  # To record whether moving to a new training
        # round or finishing the evaluation
        if self.check_buffer(self.state, min_received_num, check_eval_result):
            if not check_eval_result:
                # Receiving enough feedback in the training process
                aggregated_num = self._perform_federated_aggregation()

                if self.state % self._cfg.eval.freq == 0 and self.state != \
                        self.total_round_num:
                    # Evaluate
                    logger.info(f'Server: Starting evaluation at the end '
                                f'of round {self.state - 1}.')
                    self.eval()

                if self.state < self.total_round_num:
                    # Move to the next round of training
                    logger.info(
                        f'----------- Starting a new training round (Round '
                        f'#{self.state}) -------------')
                    # Clean the msg_buffer
                    self.msg_buffer['train'][self.state - 1].clear()
                    self.msg_buffer['train'][self.state] = dict()
                    self.staled_msg_buffer.clear()
                    # Start a new training round
                    self._start_new_training_round(aggregated_num)
                else:
                    # Final evaluation
                    logger.info('Server: Training is finished! Starting '
                                'evaluation.')
                    self.eval()
            else:
                # Receiving enough feedback in the evaluation process
                self._merge_and_format_eval_results()
        else:
            move_on_flag = False

        return move_on_flag

    def callback_funcs_global_loss(self, message: Message):
        """
        The handling function for receiving model embeddings, which
        triggers check_and_move_on_for_global_loss (calculate the global
        loss when enough feedback has been received).

        Arguments:
            message: The received message, which includes sender, receiver,
                state, and content. More details can be found in
                federatedscope.core.message.
        """
        if self.is_finish:
            return 'finish'

        round = message.state
        sender = message.sender
        timestamp = message.timestamp
        content = message.content
        self.sampler.change_state(sender, 'idle')

        # update the current timestamp according to the received message
        assert timestamp >= self.cur_timestamp  # for test
        self.cur_timestamp = timestamp

        if round == self.state:
            if round not in self.msg_buffer['train']:
                self.msg_buffer['train'][round] = dict()
            # Save the messages in this round
            self.msg_buffer['train'][round][sender] = content
        elif round >= self.state - self.staleness_toleration:
            # Save the staled messages
            self.staled_msg_buffer.append((round, sender, content))

        move_on_flag = self.check_and_move_on_for_global_loss()

        return move_on_flag

    def callback_funcs_model_para(self, message: Message):
        """
        The handling function for receiving model parameters, which
        triggers check_and_move_on (perform aggregation when enough
        feedback has been received). This handling function is widely
        used in various FL courses.

        Arguments:
            message: The received message, which includes sender, receiver,
                state, and content. More details can be found in
                federatedscope.core.message.
        """
        if self.is_finish:
            return 'finish'

        round = message.state
        sender = message.sender
        timestamp = message.timestamp
        content = message.content
        self.sampler.change_state(sender, 'idle')

        # update the current timestamp according to the received message
        assert timestamp >= self.cur_timestamp  # for test
        self.cur_timestamp = timestamp

        if round == self.state:
            if round not in self.msg_buffer['train']:
                self.msg_buffer['train'][round] = dict()
            # Save the messages in this round
            self.msg_buffer['train'][round][sender] = content
        elif round >= self.state - self.staleness_toleration:
            # Save the staled messages
            self.staled_msg_buffer.append((round, sender, content))
        else:
            # Drop the out-of-date messages
            logger.info(f'Drop an out-of-date message from round #{round}')
            self.dropout_num += 1
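
        # With online aggregation enabled, fold this update into the
        # running aggregate right away; content[:2] is assumed to carry
        # the sample size and the model parameters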
        if self._cfg.federate.online_aggr:
            self.aggregator.inc(content[:2])

        move_on_flag = self.check_and_move_on()
        if self._cfg.asyn.use and self._cfg.asyn.broadcast_manner == \
                'after_receiving':
            self.broadcast_model_para(msg_type='model_para',
                                      sample_client_num=1)

        return move_on_flag
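

# The exact loss formulation lives in federatedscope.cl.fedgc.utils.
# The sketch below is only an illustration (an assumption about the
# NT-Xent-style computation, not the library's implementation) that
# documents the shapes the server passes to global_NT_xentloss.
# It is not used by GlobalContrastFLServer.
def _sketch_global_nt_xent(z1, z2, others_z2, temperature=0.5):
    """Illustrative NT-Xent over one client's views plus peers' views."""
    import torch.nn.functional as F
    z1 = F.normalize(z1, dim=1)  # (N, d) embeddings of the first view
    z2 = F.normalize(z2, dim=1)  # (N, d) embeddings of the second view
    negatives = [F.normalize(z, dim=1) for z in others_z2]
    # Positive similarity: matching rows of the two views
    pos = torch.exp((z1 * z2).sum(dim=1) / temperature)
    # Candidate set: own second view plus all peers' second views
    candidates = torch.cat([z2] + negatives, dim=0)
    sim = torch.exp(z1 @ candidates.t() / temperature)  # (N, M)
    return -torch.log(pos / sim.sum(dim=1)).mean()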