# FS-TFP/federatedscope/cl/fedgc/server.py

import torch
import logging
import copy
import numpy as np
from federatedscope.core.message import Message
from federatedscope.core.workers.server import Server
from federatedscope.core.auxiliaries.utils import merge_dict
from federatedscope.cl.fedgc.utils import global_NT_xentloss
logger = logging.getLogger(__name__)
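
# Message flow implemented below: in every round, each sampled client uploads
# a 'pred_embedding' message; once enough arrive, the server computes a
# per-client global NT-Xent contrastive loss over all clients' embeddings and
# replies with a 'global_loss' message; clients then train locally and upload
# 'model_para' messages, which are aggregated as in standard FedAvg.
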
class GlobalContrastFLServer(Server):
    r"""
    GlobalContrastFL (FedGC) server. Its training consists of two parts: a
    FedAvg aggregator for the clients' model weights, and computation of a
    global contrastive loss from all sampled clients' embeddings, which is
    broadcast back to the clients so they can continue training against the
    global view.
    """

    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 **kwargs):
        super(GlobalContrastFLServer,
              self).__init__(ID, state, config, data, model, client_num,
                             total_round_num, device, strategy, **kwargs)
        # Initialize seqs_embedding: per-client buffer for uploaded embeddings
        self.seqs_embedding = {
            idx: ()
            for idx in range(1, self._cfg.federate.client_num + 1)
        }
        # Per-client global contrastive loss computed on the server
        self.loss_list = {
            idx: 0
            for idx in range(1, self._cfg.federate.client_num + 1)
        }
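
    # Each entry of self.seqs_embedding is assumed to be a pair (z1, z2) of
    # embedding tensors uploaded by one client (the two augmented views used
    # by the contrastive loss); check_and_move_on_for_global_loss below
    # indexes positions [0] and [1] accordingly.
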
    def _register_default_handlers(self):
        self.register_handlers('join_in', self.callback_funcs_for_join_in)
        self.register_handlers('join_in_info',
                               self.callback_funcs_for_join_in)
        self.register_handlers('model_para', self.callback_funcs_model_para)
        self.register_handlers('metrics', self.callback_funcs_for_metrics)
        self.register_handlers('pred_embedding',
                               self.callback_funcs_global_loss)
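
    # 'pred_embedding' is the FedGC-specific message type handled by this
    # server; the other registrations mirror the defaults, with both
    # 'join_in' and 'join_in_info' routed to callback_funcs_for_join_in.
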
    def check_and_move_on_for_global_loss(self):
        minimal_number = self.sample_client_num
        move_on_flag = True
        if self.check_buffer(self.state,
                             minimal_number,
                             check_eval_result=False):
            # Received enough embeddings in the training process.
            # Get all the messages of the current round.
            train_msg_buffer = self.msg_buffer['train'][self.state]
            if self.model_num != 1:
                raise ValueError(
                    'GlobalContrastFL server does not support multi-model.')
            for client_id in train_msg_buffer:
                self.seqs_embedding[client_id] = train_msg_buffer[client_id]

            # Contrast each client's pair (z1, z2) against the second views
            # uploaded by all other sampled clients.
            global_loss_fn = global_NT_xentloss(device=self.device)
            for client_id in train_msg_buffer:
                z1 = self.seqs_embedding[client_id][0]
                z2 = self.seqs_embedding[client_id][1]
                others_z2 = [
                    self.seqs_embedding[other_client_id][1]
                    for other_client_id in train_msg_buffer
                    if other_client_id != client_id
                ]
                self.loss_list[client_id] = global_loss_fn(z1, z2, others_z2)
                logger.info(f'client {client_id} '
                            f'global_loss: {self.loss_list[client_id]}')

            self.state += 1
            if self.state <= self.total_round_num:
                # Send each client its own global loss so that local
                # training can proceed with the global objective
                for client_id in train_msg_buffer:
                    content = {
                        'global_loss': self.loss_list[client_id],
                    }
                    self.comm_manager.send(
                        Message(msg_type='global_loss',
                                sender=self.ID,
                                receiver=[client_id],
                                state=self.state,
                                content=content))
        else:
            move_on_flag = False
        return move_on_flag
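
    # A sketch of what global_NT_xentloss is assumed to compute (the actual
    # implementation lives in federatedscope.cl.fedgc.utils): for a client's
    # two views z1 and z2, each pair (z1[i], z2[i]) is treated as a positive,
    # while the remaining rows of z2 together with all rows of others_z2
    # (the second views collected from the other sampled clients) act as
    # negatives, i.e. an NT-Xent / InfoNCE loss whose negative pool spans
    # the whole federation rather than a single client's batch.
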
    def check_and_move_on(self,
                          check_eval_result=False,
                          min_received_num=None):
        """
        To check the message_buffer. When enough messages have been
        received, some events (such as performing aggregation, evaluation,
        or moving on to the next training round) are triggered.

        Arguments:
            check_eval_result (bool): If True, check the message buffer for
                evaluation; otherwise, check the message buffer for training.
        """
        if min_received_num is None:
            if self._cfg.asyn.use:
                min_received_num = self._cfg.asyn.min_received_num
            else:
                min_received_num = self._cfg.federate.sample_client_num
        assert min_received_num <= self.sample_client_num

        if check_eval_result and self._cfg.federate.mode.lower(
        ) == "standalone":
            # In evaluation stage and standalone simulation mode, we assume
            # strong synchronization that receives responses from all clients
            min_received_num = len(self.comm_manager.get_neighbors().keys())

        # Record whether we move to a new training round or finish evaluation
        move_on_flag = True
        if self.check_buffer(self.state, min_received_num, check_eval_result):
            if not check_eval_result:
                # Received enough feedback in the training process
                aggregated_num = self._perform_federated_aggregation()

                if self.state % self._cfg.eval.freq == 0 and self.state != \
                        self.total_round_num:
                    # Evaluate
                    logger.info(f'Server: Starting evaluation at the end '
                                f'of round {self.state - 1}.')
                    self.eval()

                if self.state < self.total_round_num:
                    # Move to the next round of training
                    logger.info(
                        f'----------- Starting a new training round (Round '
                        f'#{self.state}) -------------')
                    # Clean the msg_buffer
                    self.msg_buffer['train'][self.state - 1].clear()
                    self.msg_buffer['train'][self.state] = dict()
                    self.staled_msg_buffer.clear()
                    # Start a new training round
                    self._start_new_training_round(aggregated_num)
                else:
                    # Final evaluation
                    logger.info('Server: Training is finished! Starting '
                                'evaluation.')
                    self.eval()
            else:
                # Received enough feedback in the evaluation process
                self._merge_and_format_eval_results()
        else:
            move_on_flag = False

        return move_on_flag
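
    # Note: this override does not advance self.state after aggregation (the
    # base Server.check_and_move_on does); in FedGC the round counter is
    # advanced once per round in check_and_move_on_for_global_loss, which is
    # why the buffers for self.state - 1 and self.state are cleaned above.
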
    def callback_funcs_global_loss(self, message: Message):
        """
        The handling function for receiving model embeddings, which triggers
        check_and_move_on_for_global_loss (calculate the global loss when
        enough feedback has been received).

        Arguments:
            message: The received message, which includes sender, receiver,
                state, and content. More detail can be found in
                federatedscope.core.message
        """
        if self.is_finish:
            return 'finish'

        round = message.state
        sender = message.sender
        timestamp = message.timestamp
        content = message.content
        self.sampler.change_state(sender, 'idle')

        # Update the current timestamp according to the received message
        assert timestamp >= self.cur_timestamp  # for test
        self.cur_timestamp = timestamp

        if round == self.state:
            if round not in self.msg_buffer['train']:
                self.msg_buffer['train'][round] = dict()
            # Save the messages in this round
            self.msg_buffer['train'][round][sender] = content
        elif round >= self.state - self.staleness_toleration:
            # Save the staled messages
            self.staled_msg_buffer.append((round, sender, content))

        move_on_flag = self.check_and_move_on_for_global_loss()
        return move_on_flag
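
    # Note: in callback_funcs_global_loss above, messages older than the
    # staleness-toleration window fall through both branches and are silently
    # ignored (no dropout_num bookkeeping), unlike callback_funcs_model_para
    # below.
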
    def callback_funcs_model_para(self, message: Message):
        """
        The handling function for receiving model parameters, which triggers
        check_and_move_on (perform aggregation when enough feedback has
        been received). This handling function is widely used in various
        FL courses.

        Arguments:
            message: The received message, which includes sender, receiver,
                state, and content. More detail can be found in
                federatedscope.core.message
        """
        if self.is_finish:
            return 'finish'

        round = message.state
        sender = message.sender
        timestamp = message.timestamp
        content = message.content
        self.sampler.change_state(sender, 'idle')

        # Update the current timestamp according to the received message
        assert timestamp >= self.cur_timestamp  # for test
        self.cur_timestamp = timestamp

        if round == self.state:
            if round not in self.msg_buffer['train']:
                self.msg_buffer['train'][round] = dict()
            # Save the messages in this round
            self.msg_buffer['train'][round][sender] = content
        elif round >= self.state - self.staleness_toleration:
            # Save the staled messages
            self.staled_msg_buffer.append((round, sender, content))
        else:
            # Drop the out-of-date messages
            logger.info(f'Drop an out-of-date message from round #{round}')
            self.dropout_num += 1

        if self._cfg.federate.online_aggr:
            self.aggregator.inc(content[:2])

        move_on_flag = self.check_and_move_on()
        if self._cfg.asyn.use and self._cfg.asyn.broadcast_manner == \
                'after_receiving':
            self.broadcast_model_para(msg_type='model_para',
                                      sample_client_num=1)

        return move_on_flag
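

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the FedGC pipeline): build
    # the global loss function exactly as the server does above and feed it
    # random embeddings shaped like the (z1, z2) pairs clients are assumed
    # to upload. Batch size 8 and embedding dim 128 are illustrative
    # assumptions, not values required by FedGC.
    loss_fn = global_NT_xentloss(device='cpu')
    z1, z2 = torch.randn(8, 128), torch.randn(8, 128)
    others_z2 = [torch.randn(8, 128) for _ in range(3)]
    print('global NT-Xent loss on random embeddings:',
          loss_fn(z1, z2, others_z2))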