import torch
import logging
import copy

from torch_geometric.loader import NeighborSampler

from federatedscope.core.message import Message
from federatedscope.core.workers.server import Server
from federatedscope.core.workers.client import Client
from federatedscope.core.auxiliaries.utils import merge_dict_of_results
from federatedscope.core.data import ClientData
from federatedscope.gfl.trainer.nodetrainer import NodeMiniBatchTrainer
from federatedscope.gfl.model.fedsageplus import LocalSage_Plus, FedSage_Plus
from federatedscope.gfl.fedsageplus.utils import GraphMender, HideGraph
from federatedscope.gfl.fedsageplus.trainer import LocalGenTrainer, \
    FedGenTrainer

logger = logging.getLogger(__name__)


class FedSagePlusServer(Server):
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 **kwargs):
        r"""
        FedSage+ consists of three training stages.
        Stage 1: 0, local pre-train of the generator.
        Stage 2: -> 2 * fedgen_epoch, federated training of the generator.
        Stage 3: -> 2 * fedgen_epoch + total_round_num, federated training
        of the GraphSAGE classifier.
        """
        super(FedSagePlusServer,
              self).__init__(ID, state, config, data, model, client_num,
                             total_round_num, device, strategy, **kwargs)

        assert self.model_num == 1, \
            "Multiple models are not supported in FedSagePlusServer"

        # If state < fedgen_epoch and state % 2 == 0:
        #     the server receives [model, embedding, label]
        # If state < fedgen_epoch and state % 2 == 1:
        #     the server receives [gradient]
        self.fedgen_epoch = 2 * self._cfg.fedsageplus.fedgen_epoch
        self.total_round_num = total_round_num + self.fedgen_epoch
        self.grad_cnt = 0

    def _register_default_handlers(self):
        self.register_handlers('join_in', self.callback_funcs_for_join_in)
        self.register_handlers('join_in_info',
                               self.callback_funcs_for_join_in)
        self.register_handlers('clf_para', self.callback_funcs_model_para)
        self.register_handlers('gen_para', self.callback_funcs_model_para)
        self.register_handlers('gradient', self.callback_funcs_gradient)
        self.register_handlers('metrics', self.callback_funcs_for_metrics)

    def callback_funcs_for_join_in(self, message: Message):
        if 'info' in message.msg_type:
            sender, info = message.sender, message.content
            for key in self._cfg.federate.join_in_info:
                assert key in info
            self.join_in_info[sender] = info
            logger.info('Server: Client #{:d} has joined in!'.format(sender))
        else:
            self.join_in_client_num += 1
            sender, address = message.sender, message.content
            if int(sender) == -1:  # assign an ID to the new client
                sender = self.join_in_client_num
                self.comm_manager.add_neighbors(neighbor_id=sender,
                                                address=address)
                self.comm_manager.send(
                    Message(msg_type='assign_client_id',
                            sender=self.ID,
                            receiver=[sender],
                            state=self.state,
                            content=str(sender)))
            else:
                self.comm_manager.add_neighbors(neighbor_id=sender,
                                                address=address)

            if len(self._cfg.federate.join_in_info) != 0:
                self.comm_manager.send(
                    Message(msg_type='ask_for_join_in_info',
                            sender=self.ID,
                            receiver=[sender],
                            state=self.state,
                            content=self._cfg.federate.join_in_info.copy()))

        if self.check_client_join_in():
            if self._cfg.federate.use_ss:
                self.broadcast_client_address()
            # All clients have joined in: kick off the local pre-training
            # of the generators (Stage 1)
            self.comm_manager.send(
                Message(msg_type='local_pretrain',
                        sender=self.ID,
                        receiver=list(self.comm_manager.neighbors.keys()),
                        state=self.state))

    def callback_funcs_gradient(self, message: Message):
        round, _, content = message.state, message.sender, message.content
        gen_grad, ID = content
        # For a new round
        if round not in self.msg_buffer['train'].keys():
            self.msg_buffer['train'][round] = dict()
        self.grad_cnt += 1
        # Sum up the gradients received from the other clients
        if ID not in self.msg_buffer['train'][round]:
            self.msg_buffer['train'][round][ID] = dict()
            for key in gen_grad.keys():
                self.msg_buffer['train'][round][ID][key] = torch.FloatTensor(
                    gen_grad[key].cpu())
        else:
            for key in gen_grad.keys():
                self.msg_buffer['train'][round][ID][key] += torch.FloatTensor(
                    gen_grad[key].cpu())
        self.check_and_move_on()
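
    # Message protocol of the generator phase (a recap of the handlers above
    # and of `check_and_move_on` below, not additional behaviour):
    #   even state: each client uploads [gen_para, embedding, num_missing];
    #       the server relays them to the other client_num - 1 clients.
    #   odd state: each of those clients returns [gen_grad, ID]; the server
    #       accumulates the client_num - 1 gradients per owner ID and sends
    #       the sum back to that owner.
    # A full generator epoch therefore consumes two server states, which is
    # why `self.fedgen_epoch` is doubled in `__init__`.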

    def check_and_move_on(self, check_eval_result=False):
        client_IDs = [i for i in range(1, self.client_num + 1)]

        if check_eval_result:
            # all clients are participating in evaluation
            minimal_number = self.client_num
        else:
            # sampled clients are participating in training
            minimal_number = self.sample_client_num

        # Transmit the generators and embeddings to get the gradients back
        if self.check_buffer(
                self.state,
                self.client_num) and self.state < self.fedgen_epoch and \
                self.state % 2 == 0:
            # FedGen: we should wait for all messages
            for sender in self.msg_buffer['train'][self.state]:
                content = self.msg_buffer['train'][self.state][sender]
                gen_para, embedding, label = content
                receiver_IDs = client_IDs[:sender - 1] + client_IDs[sender:]
                self.comm_manager.send(
                    Message(msg_type='gen_para',
                            sender=self.ID,
                            receiver=receiver_IDs,
                            state=self.state + 1,
                            content=[gen_para, embedding, label, sender]))
                logger.info(f'\tServer: Transmit gen_para to'
                            f' {receiver_IDs} @{self.state//2}.')
            self.state += 1

        # Sum up the gradients client-wise and send them back
        if self.check_buffer(
                self.state,
                self.client_num) and self.state < self.fedgen_epoch and \
                self.state % 2 == 1 and self.grad_cnt == self.client_num * (
                    self.client_num - 1):
            for ID in self.msg_buffer['train'][self.state]:
                grad = self.msg_buffer['train'][self.state][ID]
                self.comm_manager.send(
                    Message(msg_type='gradient',
                            sender=self.ID,
                            receiver=[ID],
                            state=self.state + 1,
                            content=grad))
            # reset the gradient counter
            self.grad_cnt = 0
            self.state += 1

        if self.check_buffer(
                self.state,
                self.client_num) and self.state == self.fedgen_epoch:
            self.state += 1
            # Set up the classifier trainer on each client
            self.comm_manager.send(
                Message(msg_type='setup',
                        sender=self.ID,
                        receiver=list(self.comm_manager.neighbors.keys()),
                        state=self.state))

        if self.check_buffer(self.state, minimal_number, check_eval_result
                             ) and self.state >= self.fedgen_epoch:

            if not check_eval_result:  # in the training process
                # Collect all the received messages
                train_msg_buffer = self.msg_buffer['train'][self.state]
                msg_list = list()
                for client_id in train_msg_buffer:
                    msg_list.append(train_msg_buffer[client_id])

                # Trigger the monitor here (for training)
                self._monitor.calc_model_metric(self.model.state_dict(),
                                                msg_list,
                                                rnd=self.state)

                # Aggregate
                agg_info = {
                    'client_feedback': msg_list,
                    'recover_fun': self.recover_fun
                }
                result = self.aggregator.aggregate(agg_info)
                self.model.load_state_dict(result)
                self.aggregator.update(result)

                self.state += 1
                if self.state % self._cfg.eval.freq == 0 and self.state != \
                        self.total_round_num:
                    # Evaluate
                    logger.info(
                        'Server : Starting evaluation at round {:d}.'.format(
                            self.state))
                    self.eval()

                if self.state < self.total_round_num:
                    # Move to the next round of training
                    logger.info(
                        f'----------- Starting a new training round (Round '
                        f'#{self.state}) -------------')
                    self.broadcast_model_para(
                        msg_type='model_para',
                        sample_client_num=self.sample_client_num)
                else:
                    # Final evaluation
                    logger.info('Server: Training is finished! Starting '
                                'evaluation.')
                    self.eval()
            else:  # in the evaluation process
                # Merge the evaluation results from all clients
                formatted_eval_res = \
                    self.merge_eval_results_from_all_clients()
                self.history_results = merge_dict_of_results(
                    self.history_results, formatted_eval_res)
                self.check_and_save()
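
# Worked example of the server state schedule (illustrative numbers only):
# with cfg.fedsageplus.fedgen_epoch = 3 and total_round_num = 10, the server
# has self.fedgen_epoch = 6 and self.total_round_num = 16. States 0/2/4 relay
# gen_para between clients, states 1/3/5 return the summed gradients, state 6
# sends 'setup' (which already performs one local classifier training pass),
# and the federated GraphSAGE rounds then run until state 16, where the
# final evaluation starts.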


class FedSagePlusClient(Client):
    def __init__(self,
                 ID=-1,
                 server_id=None,
                 state=-1,
                 config=None,
                 data=None,
                 model=None,
                 device='cpu',
                 strategy=None,
                 *args,
                 **kwargs):
        super(FedSagePlusClient,
              self).__init__(ID, server_id, state, config, data, model,
                             device, strategy, *args, **kwargs)
        self.data = data
        self.hide_data = HideGraph(self._cfg.fedsageplus.hide_portion)(
            data['data'])
        # Convert to `ClientData`
        self.hide_data = ClientData(self._cfg,
                                    train=[self.hide_data],
                                    val=[self.hide_data],
                                    test=[self.hide_data],
                                    data=self.hide_data)
        self.device = device
        self.sage_batch_size = 64
        self.gen = LocalSage_Plus(data['data'].x.shape[-1],
                                  self._cfg.model.out_channels,
                                  hidden=self._cfg.model.hidden,
                                  gen_hidden=self._cfg.fedsageplus.gen_hidden,
                                  dropout=self._cfg.model.dropout,
                                  num_pred=self._cfg.fedsageplus.num_pred)
        self.clf = model
        self.trainer_loc = LocalGenTrainer(self.gen,
                                           self.hide_data,
                                           self.device,
                                           self._cfg,
                                           monitor=self._monitor)

        self.register_handlers('clf_para', self.callback_funcs_for_model_para)
        self.register_handlers('local_pretrain',
                               self.callback_funcs_for_local_pre_train)
        self.register_handlers('gradient', self.callback_funcs_for_gradient)
        self.register_handlers('gen_para', self.callback_funcs_for_gen_para)
        self.register_handlers('setup', self.callback_funcs_for_setup_fedsage)

    def callback_funcs_for_local_pre_train(self, message: Message):
        round, sender, _ = message.state, message.sender, message.content
        # Local pre-training of the generator
        logger.info(f'\tClient #{self.ID} pre-train start...')
        for i in range(self._cfg.fedsageplus.loc_epoch):
            num_samples_train, _, _ = self.trainer_loc.train()
            logger.info(f'\tClient #{self.ID} local pre-train @Epoch {i}.')
        # Build fedgen based on locgen
        self.fedgen = FedSage_Plus(self.gen)
        # Build the trainer for fedgen
        self.trainer_fedgen = FedGenTrainer(self.fedgen,
                                            self.hide_data,
                                            self.device,
                                            self._cfg,
                                            monitor=self._monitor)

        gen_para = self.fedgen.cpu().state_dict()
        embedding = self.trainer_fedgen.embedding()
        self.state = round
        logger.info(f'\tClient #{self.ID} pre-train finish!')
        # Start the federated training of fedgen
        self.comm_manager.send(
            Message(msg_type='gen_para',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=[
                        gen_para, embedding,
                        self.hide_data['data'].num_missing
                    ]))
        logger.info(f'\tClient #{self.ID} send gen_para to Server #{sender}.')

    def callback_funcs_for_gen_para(self, message: Message):
        round, sender, content = message.state, message.sender, \
            message.content
        gen_para, embedding, label, ID = content

        gen_grad = self.trainer_fedgen.cal_grad(self.data['data'], gen_para,
                                                embedding, label)
        self.state = round
        self.comm_manager.send(
            Message(msg_type='gradient',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=[gen_grad, ID]))
        logger.info(f'\tClient #{self.ID}: send gradient to Server #{sender}.')

    def callback_funcs_for_gradient(self, message):
        # Update the local generator with the gradients that were summed up
        # on the server
        round, sender, content = message.state, message.sender, \
            message.content
        gen_grad = content
        self.trainer_fedgen.train()
        gen_para = self.trainer_fedgen.update_by_grad(gen_grad)
        embedding = self.trainer_fedgen.embedding()
        self.state = round
        self.comm_manager.send(
            Message(msg_type='gen_para',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=[
                        gen_para, embedding,
                        self.hide_data['data'].num_missing
                    ]))
        logger.info(f'\tClient #{self.ID}: send gen_para to Server #{sender}.')
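
    # Client-side view of the generator phase (a recap of the callbacks
    # above, not additional behaviour):
    #   'local_pretrain' -> pre-train locally, upload own gen_para/embedding;
    #   'gen_para' (another client's generator, relayed by the server)
    #       -> compute a gradient on the local graph and return it tagged
    #          with the owner's ID;
    #   'gradient' (the server-side sum of those gradients)
    #       -> take a local step, apply the summed gradient via
    #          update_by_grad, and upload the refreshed gen_para.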

    def callback_funcs_for_setup_fedsage(self, message: Message):
        round, sender, _ = message.state, message.sender, message.content
        self.filled_data = GraphMender(model=self.fedgen,
                                       impaired_data=self.hide_data[
                                           'data'].cpu(),
                                       original_data=self.data['data'])
        subgraph_sampler = NeighborSampler(
            self.filled_data.edge_index,
            sizes=[-1],
            batch_size=4096,
            shuffle=False,
            num_workers=self._cfg.dataloader.num_workers)
        fill_dataloader = {
            'data': self.filled_data,
            'train': NeighborSampler(
                self.filled_data.edge_index,
                node_idx=self.filled_data.train_idx,
                sizes=self._cfg.dataloader.sizes,
                batch_size=self.sage_batch_size,
                shuffle=self._cfg.dataloader.shuffle,
                num_workers=self._cfg.dataloader.num_workers),
            'val': subgraph_sampler,
            'test': subgraph_sampler
        }
        self._cfg.merge_from_list(
            ['dataloader.batch_size', self.sage_batch_size])
        self.trainer_clf = NodeMiniBatchTrainer(self.clf,
                                                fill_dataloader,
                                                self.device,
                                                self._cfg,
                                                monitor=self._monitor)
        sample_size, clf_para, results = self.trainer_clf.train()
        self.state = round
        logger.info(
            self._monitor.format_eval_res(results,
                                          rnd=self.state,
                                          role='Client #{}'.format(self.ID)))
        self.comm_manager.send(
            Message(msg_type='clf_para',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=(sample_size, clf_para)))

    def callback_funcs_for_model_para(self, message: Message):
        round, sender, content = message.state, message.sender, \
            message.content
        self.trainer_clf.update(content)
        self.state = round
        sample_size, clf_para, results = self.trainer_clf.train()
        if self._cfg.federate.share_local_model and not \
                self._cfg.federate.online_aggr:
            clf_para = copy.deepcopy(clf_para)
        logger.info(
            self._monitor.format_eval_res(results,
                                          rnd=self.state,
                                          role='Client #{}'.format(self.ID)))
        self.comm_manager.send(
            Message(msg_type='clf_para',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=(sample_size, clf_para)))
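
# ---------------------------------------------------------------------------
# Usage sketch (not part of this module): in FederatedScope these workers are
# normally instantiated by the framework from a YAML config rather than by
# hand. The config keys this module reads are fedsageplus.fedgen_epoch,
# loc_epoch, hide_portion, gen_hidden and num_pred. The runner call pattern
# below follows the stock entry script and is an assumption, not something
# defined here:
#
#     from federatedscope.core.fed_runner import FedRunner
#
#     runner = FedRunner(data=data,
#                        server_class=FedSagePlusServer,
#                        client_class=FedSagePlusClient,
#                        config=cfg.clone())
#     runner.run()
# ---------------------------------------------------------------------------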