# FS-TFP/federatedscope/nlp/hetero_tasks/worker/server.py
#
# Source listing: 181 lines, 7.1 KiB, Python

import os
import json
import logging
import copy
import torch
import numpy as np
from federatedscope.core.message import Message
from federatedscope.core.workers import Server
from federatedscope.nlp.hetero_tasks.trainer.utils import ContrastiveMonitor
from federatedscope.nlp.hetero_tasks.dataset.utils import load_synth_data
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ATCServer(Server):
    """Server that keeps a separate model copy and task per client.

    ``self.models[i]`` and ``self.tasks[i]`` hold the model and task for the
    i-th client, and ``broadcast_model_para`` sends each selected client its
    own entry instead of a single shared model.
    """
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 unseen_clients_id=None,
                 **kwargs):
        """Initialise the base ``Server`` and the per-client model/task state.

        All arguments are forwarded unchanged to ``Server.__init__``.
        ``config`` is additionally read for ``model.pretrain_tasks``,
        ``federate.atc_vanilla``, ``model.use_contrastive_loss``,
        ``model.stage`` and ``data``.
        """
        super().__init__(ID=ID,
                         state=state,
                         config=config,
                         data=data,
                         model=model,
                         client_num=client_num,
                         total_round_num=total_round_num,
                         device=device,
                         strategy=strategy,
                         unseen_clients_id=unseen_clients_id,
                         **kwargs)

        # multiple models are maintained for different clients
        self.models = [
            copy.deepcopy(self.model) for _ in range(self.client_num)
        ]
        # every client starts on the first configured pretraining task, or
        # on no task when ``pretrain_tasks`` is empty
        self.tasks = [
            config.model.pretrain_tasks[0]
            if config.model.pretrain_tasks else None
            for _ in range(self.client_num)
        ]

        self.atc_vanilla = config.federate.atc_vanilla
        if not self.atc_vanilla:
            # the non-vanilla aggregator is handed the per-client models and
            # the neighbour map of the communication manager
            self.aggregator.update_models(self.models)
            self.aggregator.update_neighbors(self.comm_manager.neighbors)

        self.use_contrastive_loss = self._cfg.model.use_contrastive_loss
        # NOTE(review): ``self.contrast_monitor`` is only created in the
        # 'contrast' stage, but it is read whenever ``use_contrastive_loss``
        # is set -- confirm the two settings are always enabled together.
        if self._cfg.model.stage == 'contrast':
            # load synthetic data for contrastive learning
            synth_feats, synth_toks = load_synth_data(self._cfg.data)
            self.contrast_monitor = ContrastiveMonitor()
            self.contrast_monitor.update_enc_hidden(synth_feats)
            self.contrast_monitor.update_synth_tokens(synth_toks)
            self.aggregator.update_contrast_monitor(self.contrast_monitor)

    def _perform_federated_aggregation(self):
        """Aggregate the training feedback buffered for the current round.

        In vanilla mode the aggregator receives ``[sample_size, model_para]``
        pairs and its single result is loaded into every per-client model;
        otherwise it receives the raw messages and returns per-client
        parameters plus the next task of each client.

        Returns:
            The number of aggregated messages, or ``-1`` when the contrastive
            monitor is in state 2 (outside the 'pretrain' task): the training
            buffer is cleared and the models are re-broadcast instead.
        """
        train_msg_buffer = self.msg_buffer['train'][self.state]
        # present feedback to the aggregator ordered by client ID
        msg_list = [
            msg for _, msg in sorted(train_msg_buffer.items(),
                                     key=lambda item: item[0])
        ]
        aggregated_num = len(msg_list)

        if self.atc_vanilla:
            agg_info = {
                'client_feedback': [[msg['sample_size'], msg['model_para']]
                                    for msg in msg_list],
                'recover_fun': self.recover_fun,
            }
            avg_models = self.aggregator.aggregate(agg_info)
            # one aggregated state dict is shared by all clients
            for model in self.models:
                model.load_state_dict(avg_models, strict=False)
            tasks = [None] * self.client_num
        else:
            agg_info = {
                'client_feedback': msg_list,
                'recover_fun': self.recover_fun,
            }
            avg_models, tasks = self.aggregator.aggregate(agg_info)
            if avg_models is not None and 'model_para' in avg_models:
                for i in range(self.client_num):
                    self.models[i].load_state_dict(
                        avg_models['model_para'][i], strict=False)
        self.tasks = tasks

        if self.use_contrastive_loss:
            if self._cfg.model.task != 'pretrain' and \
                    self.contrast_monitor.stat == 2:
                # restart the round: drop the collected feedback and send
                # the models out again
                self.msg_buffer['train'][self.state].clear()
                self.broadcast_model_para(
                    msg_type='model_para',
                    sample_client_num=self.sample_client_num)
                return -1
            if self.contrast_monitor.stat == 3:
                self.contrast_monitor.reset()

        return aggregated_num

    def broadcast_model_para(self,
                             msg_type='model_para',
                             sample_client_num=-1,
                             filter_unseen_clients=True):
        """Send each selected client its own model parameters and task.

        Args:
            msg_type: message type of every outgoing ``Message``.
            sample_client_num: when positive, this many client indices are
                drawn uniformly without replacement with ``np.random.choice``;
                otherwise all ``client_num`` clients are addressed.
            filter_unseen_clients: mark ``unseen_clients_id`` as 'unseen' in
                the sampler while broadcasting and restore them to 'seen'
                afterwards (also when sending raises).
        """
        if filter_unseen_clients:
            self.sampler.change_state(self.unseen_clients_id, 'unseen')
        try:
            if sample_client_num > 0:
                # NOTE(review): this draw does not go through
                # ``self.sampler``, so the 'unseen' marking above does not
                # influence which clients are chosen -- confirm intent.
                sample_ids = np.random.choice(np.arange(self.client_num),
                                              size=sample_client_num,
                                              replace=False).tolist()
            else:
                sample_ids = list(range(self.client_num))

            # the i-th model/task is sent to the i-th neighbour in ID order
            receivers = sorted(self.comm_manager.neighbors.keys())
            if self._cfg.federate.method in ['local', 'global']:
                # these methods send empty parameter dicts, so skip the
                # (discarded) state_dict() copies entirely
                model_para = [{} for _ in self.models]
            else:
                model_para = [model.state_dict() for model in self.models]

            for i in sample_ids:
                content = {
                    'model_para': model_para[i],
                    'task': self.tasks[i],
                }
                if self.use_contrastive_loss:
                    content['contrast_monitor'] = self.contrast_monitor
                self.comm_manager.send(
                    Message(msg_type=msg_type,
                            sender=self.ID,
                            receiver=receivers[i],
                            state=self.state,
                            content=content))
        finally:
            if filter_unseen_clients:
                self.sampler.change_state(self.unseen_clients_id, 'seen')

    def merge_eval_results_from_all_clients(self, final_round=False):
        """Collect evaluation metrics from all clients, then log and save them.

        Unless ``'group_avg'`` is among ``eval.report`` forms (in which case
        the raw buffer is passed through), every client's metrics are merged
        into ``{metric_name: [value_per_client, ...]}``; nested dict results
        are flattened to ``'<key>_<subkey>'``.

        Args:
            final_round: when True, read the evaluation buffer of
                ``self.state - 1`` instead of ``self.state``.

        Returns:
            The logs produced by ``self._monitor.format_eval_res``.
        """
        state = self.state if not final_round else self.state - 1
        eval_msg_buffer = self.msg_buffer['eval'][state]

        if 'group_avg' in self._cfg.eval.report:
            metrics_all_clients = eval_msg_buffer
        else:
            metrics_all_clients = dict()
            for client_eval_results in eval_msg_buffer.values():
                for key, res in client_eval_results.items():
                    if isinstance(res, dict):
                        for k, v in res.items():
                            # key the list on the flattened name; checking
                            # ``key`` here used to reset the list for every
                            # client, keeping only the last client's value
                            metrics_all_clients.setdefault(
                                key + '_' + k, []).append(float(v))
                    else:
                        metrics_all_clients.setdefault(key, []).append(
                            float(res))

        formatted_logs = self._monitor.format_eval_res(
            metrics_all_clients,
            rnd=self.state + 1,
            role='Server #',
            forms=self._cfg.eval.report)
        logger.info(formatted_logs)
        self._monitor.save_formatted_results(formatted_logs)
        return formatted_logs