# FedSage+ federated-learning workers (server and client).
import torch
import logging
import copy

from torch_geometric.loader import NeighborSampler

# Core FederatedScope worker / messaging infrastructure.
from federatedscope.core.message import Message
from federatedscope.core.workers.server import Server
from federatedscope.core.workers.client import Client
from federatedscope.core.auxiliaries.utils import merge_dict_of_results
from federatedscope.core.data import ClientData

# Graph-FL specific trainers and the FedSage+ generator models.
from federatedscope.gfl.trainer.nodetrainer import NodeMiniBatchTrainer
from federatedscope.gfl.model.fedsageplus import LocalSage_Plus, FedSage_Plus
from federatedscope.gfl.fedsageplus.utils import GraphMender, HideGraph
from federatedscope.gfl.fedsageplus.trainer import LocalGenTrainer, \
    FedGenTrainer

logger = logging.getLogger(__name__)
class FedSagePlusServer(Server):
    r"""
    Server-side worker for FedSage+ (subgraph federated learning with
    missing-neighbor generation).

    The server drives a three-stage protocol via ``self.state``:

    * Stage 1 (state 0): broadcast ``'local_pretrain'`` so each client
      pre-trains its local neighbor generator.
    * Stage 2 (state < 2 * fedgen_epoch): federated generator training.
      Even states collect ``[gen_para, embedding, label]`` from each
      client and fan them out to all other clients; odd states collect
      the resulting cross-client gradients and route the per-owner sums
      back.
    * Stage 3 (state >= 2 * fedgen_epoch): ordinary federated training of
      the GraphSAGE classifier (aggregate 'clf_para', evaluate, repeat).
    """
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 **kwargs):
        r"""
        FedSage+ consists of three of training stages.
        Stage1: 0, local pre-train for generator.
        Stage2: -> 2 * fedgen_epoch, federated training for generator.
        Stage3: -> 2 * fedgen_epoch + total_round_num: federated training
        for GraphSAGE Classifier
        """
        super(FedSagePlusServer,
              self).__init__(ID, state, config, data, model, client_num,
                             total_round_num, device, strategy, **kwargs)

        # FedSage+ maintains a single classifier model on the server.
        assert self.model_num == 1, "Not supported multi-model for " \
                                    "FedSagePlusServer"

        # If state < fedgen_epoch and state % 2 == 0:
        #     Server receive [model, embedding, label]
        # If state < fedgen_epoch and state % 2 == 1:
        #     Server receive [gradient]
        # Each generator epoch costs two states (send-out + gradient),
        # hence the factor of two.
        self.fedgen_epoch = 2 * self._cfg.fedsageplus.fedgen_epoch
        # Generator rounds are prepended to the classifier rounds, so the
        # effective schedule is fedgen_epoch + total_round_num states.
        self.total_round_num = total_round_num + self.fedgen_epoch
        # Number of gradient messages buffered for the current generator
        # round; a complete round needs client_num * (client_num - 1).
        self.grad_cnt = 0

    def _register_default_handlers(self):
        # Override the base Server handlers: FedSage+ exchanges generator
        # parameters ('gen_para'), classifier parameters ('clf_para') and
        # cross-client gradients instead of a plain 'model_para' flow.
        self.register_handlers('join_in', self.callback_funcs_for_join_in)
        self.register_handlers('join_in_info', self.callback_funcs_for_join_in)
        self.register_handlers('clf_para', self.callback_funcs_model_para)
        self.register_handlers('gen_para', self.callback_funcs_model_para)
        self.register_handlers('gradient', self.callback_funcs_gradient)
        self.register_handlers('metrics', self.callback_funcs_for_metrics)

    def callback_funcs_for_join_in(self, message: Message):
        """Handle both 'join_in' and 'join_in_info' messages.

        Registers the joining client (assigning an ID when the client sent
        ``-1``), optionally requests additional join-in info, and — once
        every expected client has joined — broadcasts ``'local_pretrain'``
        to start Stage 1.
        """
        if 'info' in message.msg_type:
            # 'join_in_info' reply: validate that every requested key is
            # present before storing it.
            sender, info = message.sender, message.content
            for key in self._cfg.federate.join_in_info:
                assert key in info
            self.join_in_info[sender] = info
            logger.info('Server: Client #{:d} has joined in !'.format(sender))
        else:
            self.join_in_client_num += 1
            sender, address = message.sender, message.content
            if int(sender) == -1:  # assign number to client
                sender = self.join_in_client_num
                self.comm_manager.add_neighbors(neighbor_id=sender,
                                                address=address)
                self.comm_manager.send(
                    Message(msg_type='assign_client_id',
                            sender=self.ID,
                            receiver=[sender],
                            state=self.state,
                            content=str(sender)))
            else:
                self.comm_manager.add_neighbors(neighbor_id=sender,
                                                address=address)

            if len(self._cfg.federate.join_in_info) != 0:
                self.comm_manager.send(
                    Message(msg_type='ask_for_join_in_info',
                            sender=self.ID,
                            receiver=[sender],
                            state=self.state,
                            content=self._cfg.federate.join_in_info.copy()))

        if self.check_client_join_in():
            if self._cfg.federate.use_ss:
                self.broadcast_client_address()

            # All clients have joined: kick off Stage 1 (local
            # pre-training of the neighbor generators).
            self.comm_manager.send(
                Message(msg_type='local_pretrain',
                        sender=self.ID,
                        receiver=list(self.comm_manager.neighbors.keys()),
                        state=self.state))

    def callback_funcs_gradient(self, message: Message):
        """Buffer a generator gradient reported by a client.

        ``message.content`` is ``[gen_grad, ID]`` where ``ID`` identifies
        the client whose generator the gradient belongs to. Gradients
        computed by different evaluating clients for the same owner ``ID``
        are summed element-wise on CPU.
        """
        round, _, content = message.state, message.sender, message.content
        gen_grad, ID = content
        # For a new round
        if round not in self.msg_buffer['train'].keys():
            self.msg_buffer['train'][round] = dict()
        self.grad_cnt += 1
        # Sum up all grad from other client
        if ID not in self.msg_buffer['train'][round]:
            self.msg_buffer['train'][round][ID] = dict()
            for key in gen_grad.keys():
                self.msg_buffer['train'][round][ID][key] = torch.FloatTensor(
                    gen_grad[key].cpu())
        else:
            for key in gen_grad.keys():
                self.msg_buffer['train'][round][ID][key] += torch.FloatTensor(
                    gen_grad[key].cpu())
        self.check_and_move_on()

    def check_and_move_on(self, check_eval_result=False):
        """Advance the protocol state machine when the current round's
        message buffer is complete.

        The four ``if`` blocks below correspond, in order, to: (1) even
        generator states — fan each client's [gen_para, embedding, label]
        out to all other clients; (2) odd generator states — return the
        summed gradients to their owners; (3) the transition state — tell
        clients to set up the classifier trainer; (4) classifier states —
        aggregate 'clf_para' feedback (or merge evaluation results when
        ``check_eval_result`` is True). The blocks are order-dependent:
        each may increment ``self.state`` so a later block can fire in the
        same call.
        """
        client_IDs = [i for i in range(1, self.client_num + 1)]

        if check_eval_result:
            # all clients are participating in evaluation
            minimal_number = self.client_num
        else:
            # sampled clients are participating in training
            minimal_number = self.sample_client_num

        # Transmit model and embedding to get gradient back
        if self.check_buffer(
                self.state, self.client_num
        ) and self.state < self._cfg.fedsageplus.fedgen_epoch and self.state\
                % 2 == 0:
            # FedGen: we should wait for all messages
            for sender in self.msg_buffer['train'][self.state]:
                content = self.msg_buffer['train'][self.state][sender]
                gen_para, embedding, label = content
                # Every client except the sender evaluates the sender's
                # generator (client IDs are 1-based).
                receiver_IDs = client_IDs[:sender - 1] + client_IDs[sender:]
                self.comm_manager.send(
                    Message(msg_type='gen_para',
                            sender=self.ID,
                            receiver=receiver_IDs,
                            state=self.state + 1,
                            content=[gen_para, embedding, label, sender]))
                logger.info(f'\tServer: Transmit gen_para to'
                            f' {receiver_IDs} @{self.state//2}.')
            self.state += 1

        # Sum up gradient client-wisely and send back
        if self.check_buffer(
                self.state, self.client_num
        ) and self.state < self._cfg.fedsageplus.fedgen_epoch and self.state\
                % 2 == 1 and self.grad_cnt == self.client_num * (
                self.client_num - 1):
            for ID in self.msg_buffer['train'][self.state]:
                grad = self.msg_buffer['train'][self.state][ID]
                self.comm_manager.send(
                    Message(msg_type='gradient',
                            sender=self.ID,
                            receiver=[ID],
                            state=self.state + 1,
                            content=grad))
            # reset num of grad counter
            self.grad_cnt = 0
            self.state += 1

        if self.check_buffer(
                self.state, self.client_num
        ) and self.state == self._cfg.fedsageplus.fedgen_epoch:
            self.state += 1
            # Setup Clf_trainer for each client
            self.comm_manager.send(
                Message(msg_type='setup',
                        sender=self.ID,
                        receiver=list(self.comm_manager.neighbors.keys()),
                        state=self.state))

        if self.check_buffer(
                self.state, minimal_number, check_eval_result
        ) and self.state >= self._cfg.fedsageplus.fedgen_epoch:

            if not check_eval_result:  # in the training process
                # Get all the message
                train_msg_buffer = self.msg_buffer['train'][self.state]
                msg_list = list()
                for client_id in train_msg_buffer:
                    msg_list.append(train_msg_buffer[client_id])

                # Trigger the monitor here (for training)
                self._monitor.calc_model_metric(self.model.state_dict(),
                                                msg_list,
                                                rnd=self.state)

                # Aggregate
                agg_info = {
                    'client_feedback': msg_list,
                    'recover_fun': self.recover_fun
                }
                result = self.aggregator.aggregate(agg_info)
                self.model.load_state_dict(result)
                self.aggregator.update(result)

                self.state += 1
                if self.state % self._cfg.eval.freq == 0 and self.state != \
                        self.total_round_num:
                    # Evaluate
                    logger.info(
                        'Server : Starting evaluation at round {:d}.'.format(
                            self.state))
                    self.eval()

                if self.state < self.total_round_num:
                    # Move to next round of training
                    logger.info(
                        f'----------- Starting a new training round(Round '
                        f'#{self.state}) -------------')
                    self.broadcast_model_para(
                        msg_type='model_para',
                        sample_client_num=self.sample_client_num)
                else:
                    # Final Evaluate
                    logger.info('Server: Training is finished! Starting '
                                'evaluation.')
                    self.eval()

            else:  # in the evaluation process
                # Get all the message & aggregate
                formatted_eval_res = self.merge_eval_results_from_all_clients()
                self.history_results = merge_dict_of_results(
                    self.history_results, formatted_eval_res)
                self.check_and_save()
class FedSagePlusClient(Client):
    r"""
    Client-side worker for FedSage+.

    On construction the client hides a portion of its local subgraph
    (``HideGraph``) to build an "impaired" graph for generator training,
    instantiates a local neighbor generator (``LocalSage_Plus``) and its
    trainer, and keeps the shared GraphSAGE classifier in ``self.clf``.
    The registered handlers mirror the server's three-stage protocol:
    pre-train -> federated generator training ('gen_para'/'gradient') ->
    classifier setup and federated classifier training ('clf_para').
    """
    def __init__(self,
                 ID=-1,
                 server_id=None,
                 state=-1,
                 config=None,
                 data=None,
                 model=None,
                 device='cpu',
                 strategy=None,
                 *args,
                 **kwargs):
        super(FedSagePlusClient,
              self).__init__(ID, server_id, state, config, data, model, device,
                             strategy, *args, **kwargs)
        self.data = data
        # Hide a fraction of nodes/edges to create the impaired local
        # graph the generator learns to repair.
        self.hide_data = HideGraph(self._cfg.fedsageplus.hide_portion)(
            data['data'])
        # Convert to `ClientData`
        self.hide_data = ClientData(self._cfg,
                                    train=[self.hide_data],
                                    val=[self.hide_data],
                                    test=[self.hide_data],
                                    data=self.hide_data)
        self.device = device
        # Mini-batch size used later for the classifier's NeighborSampler.
        self.sage_batch_size = 64
        self.gen = LocalSage_Plus(data['data'].x.shape[-1],
                                  self._cfg.model.out_channels,
                                  hidden=self._cfg.model.hidden,
                                  gen_hidden=self._cfg.fedsageplus.gen_hidden,
                                  dropout=self._cfg.model.dropout,
                                  num_pred=self._cfg.fedsageplus.num_pred)
        self.clf = model
        # Trainer for the purely local generator pre-training (Stage 1).
        self.trainer_loc = LocalGenTrainer(self.gen,
                                           self.hide_data,
                                           self.device,
                                           self._cfg,
                                           monitor=self._monitor)

        self.register_handlers('clf_para', self.callback_funcs_for_model_para)
        self.register_handlers('local_pretrain',
                               self.callback_funcs_for_local_pre_train)
        self.register_handlers('gradient', self.callback_funcs_for_gradient)
        self.register_handlers('gen_para', self.callback_funcs_for_gen_para)
        self.register_handlers('setup', self.callback_funcs_for_setup_fedsage)

    def callback_funcs_for_local_pre_train(self, message: Message):
        """Stage 1: pre-train the local generator, then send its
        parameters, node embeddings and missing-node labels to the server
        to start federated generator training."""
        round, sender, _ = message.state, message.sender, message.content
        # Local pre-train
        logger.info(f'\tClient #{self.ID} pre-train start...')
        for i in range(self._cfg.fedsageplus.loc_epoch):
            num_samples_train, _, _ = self.trainer_loc.train()
            logger.info(f'\tClient #{self.ID} local pre-train @Epoch {i}.')
        # Build fedgen base on locgen
        self.fedgen = FedSage_Plus(self.gen)
        # Build trainer for fedgen
        self.trainer_fedgen = FedGenTrainer(self.fedgen,
                                            self.hide_data,
                                            self.device,
                                            self._cfg,
                                            monitor=self._monitor)

        gen_para = self.fedgen.cpu().state_dict()
        embedding = self.trainer_fedgen.embedding()
        self.state = round
        logger.info(f'\tClient #{self.ID} pre-train finish!')
        # Start the training of fedgen
        self.comm_manager.send(
            Message(msg_type='gen_para',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=[
                        gen_para, embedding, self.hide_data['data'].num_missing
                    ]))
        logger.info(f'\tClient #{self.ID} send gen_para to Server #{sender}.')

    def callback_funcs_for_gen_para(self, message: Message):
        """Even generator state: evaluate another client's generator
        (identified by ``ID``) on the local graph and return the computed
        gradient to the server."""
        round, sender, content = message.state, message.sender, message.content
        gen_para, embedding, label, ID = content

        gen_grad = self.trainer_fedgen.cal_grad(self.data['data'], gen_para,
                                                embedding, label)
        self.state = round
        self.comm_manager.send(
            Message(msg_type='gradient',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=[gen_grad, ID]))
        logger.info(f'\tClient #{self.ID}: send gradient to Server #{sender}.')

    def callback_funcs_for_gradient(self, message: Message):
        """Odd generator state: apply the summed cross-client gradient to
        the local generator, then send the refreshed parameters and
        embeddings back to the server for the next generator round."""
        # Aggregate gen_grad on server
        round, sender, content = message.state, message.sender, message.content
        gen_grad = content
        self.trainer_fedgen.train()
        gen_para = self.trainer_fedgen.update_by_grad(gen_grad)
        embedding = self.trainer_fedgen.embedding()
        self.state = round
        self.comm_manager.send(
            Message(msg_type='gen_para',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=[
                        gen_para, embedding, self.hide_data['data'].num_missing
                    ]))
        logger.info(f'\tClient #{self.ID}: send gen_para to Server #{sender}.')

    def callback_funcs_for_setup_fedsage(self, message: Message):
        """Stage transition: mend the local graph with the trained
        generator, build mini-batch dataloaders over the filled graph,
        create the classifier trainer, run the first local training round
        and report 'clf_para' to the server."""
        round, sender, _ = message.state, message.sender, message.content
        # Repair the impaired graph with generator-predicted neighbors.
        self.filled_data = GraphMender(
            model=self.fedgen,
            impaired_data=self.hide_data['data'].cpu(),
            original_data=self.data['data'])
        # Full-neighborhood sampler (sizes=[-1]) shared by val and test.
        subgraph_sampler = NeighborSampler(
            self.filled_data.edge_index,
            sizes=[-1],
            batch_size=4096,
            shuffle=False,
            num_workers=self._cfg.dataloader.num_workers)
        fill_dataloader = {
            'data': self.filled_data,
            'train': NeighborSampler(
                self.filled_data.edge_index,
                node_idx=self.filled_data.train_idx,
                sizes=self._cfg.dataloader.sizes,
                batch_size=self.sage_batch_size,
                shuffle=self._cfg.dataloader.shuffle,
                num_workers=self._cfg.dataloader.num_workers),
            'val': subgraph_sampler,
            'test': subgraph_sampler
        }
        # Keep the config consistent with the sampler's batch size.
        self._cfg.merge_from_list(
            ['dataloader.batch_size', self.sage_batch_size])
        self.trainer_clf = NodeMiniBatchTrainer(self.clf,
                                                fill_dataloader,
                                                self.device,
                                                self._cfg,
                                                monitor=self._monitor)
        sample_size, clf_para, results = self.trainer_clf.train()
        self.state = round
        logger.info(
            self._monitor.format_eval_res(results,
                                          rnd=self.state,
                                          role='Client #{}'.format(self.ID)))
        self.comm_manager.send(
            Message(msg_type='clf_para',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=(sample_size, clf_para)))

    def callback_funcs_for_model_para(self, message: Message):
        """Stage 3: receive the aggregated classifier parameters, train
        locally for one round and send the updated 'clf_para' back."""
        round, sender, content = message.state, message.sender, message.content
        self.trainer_clf.update(content)
        self.state = round
        sample_size, clf_para, results = self.trainer_clf.train()
        if self._cfg.federate.share_local_model and not \
                self._cfg.federate.online_aggr:
            # Deep-copy so later in-place local updates don't mutate the
            # parameters already handed to the server.
            clf_para = copy.deepcopy(clf_para)
        logger.info(
            self._monitor.format_eval_res(results,
                                          rnd=self.state,
                                          role='Client #{}'.format(self.ID)))
        self.comm_manager.send(
            Message(msg_type='clf_para',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=(sample_size, clf_para)))