import os
import json
import logging
import copy

import torch
import numpy as np

from federatedscope.core.message import Message
from federatedscope.core.workers import Server
from federatedscope.nlp.hetero_tasks.trainer.utils import ContrastiveMonitor
from federatedscope.nlp.hetero_tasks.dataset.utils import load_synth_data

logger = logging.getLogger(__name__)


class ATCServer(Server):
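    """Server of the ATC approach for heterogeneous NLP tasks.

    Compared with the base ``Server``, it maintains one model copy and one
    task assignment per client, lets the ATC-aware aggregator update these
    models and tasks (unless ``federate.atc_vanilla`` is set), and, in the
    ``contrast`` stage, shares a ``ContrastiveMonitor`` built from synthetic
    data with the clients.
    """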
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 unseen_clients_id=None,
                 **kwargs):
        super().__init__(ID=ID,
                         state=state,
                         config=config,
                         data=data,
                         model=model,
                         client_num=client_num,
                         total_round_num=total_round_num,
                         device=device,
                         strategy=strategy,
                         unseen_clients_id=unseen_clients_id,
                         **kwargs)

        # Multiple models are maintained, one for each client
        self.models = [
            copy.deepcopy(self.model) for _ in range(self.client_num)
        ]
        # Every client starts from the first pretraining task (if any);
        # the aggregator may reassign tasks after each round
        self.tasks = [
            config.model.pretrain_tasks[0]
            if config.model.pretrain_tasks else None
            for _ in range(self.client_num)
        ]
        self.atc_vanilla = config.federate.atc_vanilla
        if not self.atc_vanilla:
            self.aggregator.update_models(self.models)
            self.aggregator.update_neighbors(self.comm_manager.neighbors)

        self.use_contrastive_loss = self._cfg.model.use_contrastive_loss
        if self._cfg.model.stage == 'contrast':
            # Load synthetic data for contrastive learning
            synth_feats, synth_toks = load_synth_data(self._cfg.data)
            self.contrast_monitor = ContrastiveMonitor()
            self.contrast_monitor.update_enc_hidden(synth_feats)
            self.contrast_monitor.update_synth_tokens(synth_toks)
            self.aggregator.update_contrast_monitor(self.contrast_monitor)

    def _perform_federated_aggregation(self):
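        """Aggregate the model updates buffered for the current round.

        In vanilla mode all per-client models are overwritten with the same
        averaged parameters; otherwise the aggregator returns one set of
        parameters and one task per client. During the contrastive stage
        (``contrast_monitor.stat == 2``) the train buffer is cleared and the
        models are re-broadcast instead, in which case ``-1`` is returned
        rather than the number of aggregated updates.
        """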
        train_msg_buffer = dict(
            sorted(self.msg_buffer['train'][self.state].items(),
                   key=lambda x: x[0]))
        msg_list = list()
        for client_id in train_msg_buffer:
            msg_list.append(train_msg_buffer[client_id])

        # Aggregate
        aggregated_num = len(msg_list)
        if self.atc_vanilla:
            # Vanilla mode: sample-size weighted averaging, every client
            # model is overwritten with the same averaged parameters
            agg_info = {
                'client_feedback': [[x['sample_size'], x['model_para']]
                                    for x in msg_list],
                'recover_fun': self.recover_fun,
            }
            avg_models = self.aggregator.aggregate(agg_info)
            tasks = [None for _ in range(self.client_num)]
            for i in range(self.client_num):
                self.models[i].load_state_dict(avg_models, strict=False)
        else:
            # ATC aggregation returns per-client parameters together with
            # the task assigned to each client for the next round
            agg_info = {
                'client_feedback': msg_list,
                'recover_fun': self.recover_fun,
            }
            avg_models, tasks = self.aggregator.aggregate(agg_info)
            if avg_models is not None and 'model_para' in avg_models:
                for i in range(self.client_num):
                    self.models[i].load_state_dict(
                        avg_models['model_para'][i], strict=False)
        self.tasks = tasks

        if self.use_contrastive_loss:
            if self._cfg.model.task != 'pretrain' and \
                    self.contrast_monitor.stat == 2:
                # The contrastive stage requires another model exchange
                # within the same round: clear the buffer, broadcast again
                # and return -1 instead of the number of aggregated updates
                self.msg_buffer['train'][self.state].clear()
                self.broadcast_model_para(
                    msg_type='model_para',
                    sample_client_num=self.sample_client_num)
                return -1
            if self.contrast_monitor.stat == 3:
                self.contrast_monitor.reset()

        return aggregated_num

    def broadcast_model_para(self,
                             msg_type='model_para',
                             sample_client_num=-1,
                             filter_unseen_clients=True):
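        """Send each sampled client its own model parameters and task.

        When contrastive learning is enabled, the shared
        ``contrast_monitor`` is attached to every message. For the
        ``local`` and ``global`` baselines an empty parameter dict is
        broadcast so that no real model exchange takes place.
        """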
        if filter_unseen_clients:
            # Prevent the unseen clients from being selected for training
            self.sampler.change_state(self.unseen_clients_id, 'unseen')

        if sample_client_num > 0:
            sample_ids = np.random.choice(np.arange(self.client_num),
                                          size=sample_client_num,
                                          replace=False).tolist()
        else:
            # Broadcast to all clients
            sample_ids = list(range(self.client_num))

        receivers = sorted(list(self.comm_manager.neighbors.keys()))
        model_para = [model.state_dict() for model in self.models]
        # The 'local' and 'global' baselines do not exchange model parameters
        skip_broadcast = self._cfg.federate.method in ['local', 'global']
        if skip_broadcast:
            model_para = [{} for _ in self.models]

        for i in sample_ids:
            if not self.use_contrastive_loss:
                content = {
                    'model_para': model_para[i],
                    'task': self.tasks[i],
                }
            else:
                content = {
                    'model_para': model_para[i],
                    'task': self.tasks[i],
                    'contrast_monitor': self.contrast_monitor,
                }
            self.comm_manager.send(
                Message(msg_type=msg_type,
                        sender=self.ID,
                        receiver=receivers[i],
                        state=self.state,
                        content=content))

        if filter_unseen_clients:
            # Restore the state of the unseen clients within the sampler
            self.sampler.change_state(self.unseen_clients_id, 'seen')

    def merge_eval_results_from_all_clients(self, final_round=False):
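        """Merge the evaluation results reported by all clients.

        Uses the results of the current round (or of the previous round
        when ``final_round`` is True), formats them with the monitor, logs
        and saves them, and returns the formatted results.
        """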
        state = self.state if not final_round else self.state - 1
        eval_msg_buffer = self.msg_buffer['eval'][state]

        if 'group_avg' in self._cfg.eval.report:
            metrics_all_clients = eval_msg_buffer
        else:
            # Collect each metric into a list over clients, flattening
            # nested result dicts by joining their keys with '_'
            metrics_all_clients = dict()
            for each_client in eval_msg_buffer:
                client_eval_results = eval_msg_buffer[each_client]
                for key in client_eval_results.keys():
                    res = client_eval_results[key]
                    if isinstance(res, dict):
                        for k, v in res.items():
                            cur_key = key + '_' + k
                            if cur_key not in metrics_all_clients:
                                metrics_all_clients[cur_key] = list()
                            metrics_all_clients[cur_key].append(float(v))
                    else:
                        if key not in metrics_all_clients:
                            metrics_all_clients[key] = list()
                        metrics_all_clients[key].append(float(res))

        formatted_logs = self._monitor.format_eval_res(
            metrics_all_clients,
            rnd=self.state + 1,
            role='Server #',
            forms=self._cfg.eval.report)
        logger.info(formatted_logs)
        self._monitor.save_formatted_results(formatted_logs)
        return formatted_logs
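
# A minimal sketch of the config entries read by ATCServer. The field names
# are taken from the code above; the values are illustrative assumptions,
# not verified FederatedScope defaults:
#
#   cfg.federate.atc_vanilla = False        # True -> plain averaged aggregation
#   cfg.federate.method = 'FedAvg'          # 'local'/'global' skip broadcasting
#   cfg.model.pretrain_tasks = ['mlm']      # initial task assigned to clients
#   cfg.model.stage = 'contrast'            # build/share the ContrastiveMonitor
#   cfg.model.use_contrastive_loss = True   # attach the monitor to messages
#   cfg.model.task = 'pretrain'             # non-'pretrain' allows re-broadcast
#   cfg.eval.report = ['avg', 'group_avg']  # 'group_avg' skips metric flattening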