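"""Standalone multi-GPU runner for FederatedScope.

This module simulates federated learning with one server process and
several client processes: lightweight control messages are exchanged via
``torch.multiprocessing`` queues, while model parameters are streamed
between processes with ``torch.distributed`` (NCCL).
"""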
import logging
import time
import os
import copy
import heapq

import numpy as np
import torch
import torch.multiprocessing as mp
import torch.distributed as dist

from federatedscope.core.fed_runner import StandaloneRunner
from federatedscope.core.auxiliaries.model_builder import get_model
from federatedscope.core.auxiliaries.feat_engr_builder import \
    get_feat_engr_wrapper
from federatedscope.core.auxiliaries.data_builder import get_data

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def recv_model_para(model_para, src_rank):
    # Receive each parameter tensor from the given source rank, writing
    # into the (correctly-shaped) tensors of ``model_para`` in-place
    for v in model_para.values():
        dist.recv(tensor=v, src=src_rank)
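

# A plausible sending-side counterpart (a minimal sketch, not part of the
# original module; the helper name ``send_model_para`` is hypothetical):
# the peer that produced the parameters iterates the same state_dict in the
# same key order and streams each tensor with ``dist.send``.
def send_model_para(model_para, dst_rank):
    for v in model_para.values():
        dist.send(tensor=v, dst=dst_rank)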


def setup_multigpu_runner(cfg, server_class, client_class, unseen_clients_id,
                          server_resource_info, client_resource_info):
    processes = []
    mp.set_start_method("spawn")

    # init communication: one queue for client->server messages, and one
    # queue per client process for server->client messages
    client2server_queue = mp.Queue()
    server2client_queues = [
        mp.Queue() for _ in range(1, cfg.federate.process_num)
    ]
    id2comm = dict()
    clients_id_list = []
    # evenly partition the clients over the client processes; the last
    # process absorbs the remainder
    client_num_per_process = \
        cfg.federate.client_num // (cfg.federate.process_num - 1)
    for process_id in range(1, cfg.federate.process_num):
        client_ids_start = (process_id - 1) * client_num_per_process + 1
        client_ids_end = client_ids_start + client_num_per_process \
            if process_id != cfg.federate.process_num - 1 \
            else cfg.federate.client_num + 1
        clients_id_list.append(range(client_ids_start, client_ids_end))
        for client_id in range(client_ids_start, client_ids_end):
            id2comm[client_id] = process_id - 1
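    # For example, with client_num=10 and process_num=4 (one server process
    # plus 3 client processes), the partition above yields clients 1-3 on
    # rank 1, clients 4-6 on rank 2, and clients 7-10 on rank 3; id2comm
    # maps each client id to the index (0-2) of its process's queue.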

    # setup the server process
    server_rank = 0
    server_process = mp.Process(
        target=run,
        args=(server_rank, cfg.federate.process_num, cfg.federate.master_addr,
              cfg.federate.master_port,
              ServerRunner(rank=server_rank,
                           config=cfg,
                           server_class=server_class,
                           receive_channel=client2server_queue,
                           send_channels=server2client_queues,
                           id2comm=id2comm,
                           unseen_clients_id=unseen_clients_id,
                           resource_info=server_resource_info,
                           client_resource_info=client_resource_info)))
    server_process.start()
    processes.append(server_process)

    # setup the client processes
    for rank in range(1, cfg.federate.process_num):
        client_runner = ClientRunner(
            rank=rank,
            client_ids=clients_id_list[rank - 1],
            config=cfg,
            client_class=client_class,
            unseen_clients_id=unseen_clients_id,
            receive_channel=server2client_queues[rank - 1],
            send_channel=client2server_queue,
            client_resource_info=client_resource_info)
        p = mp.Process(target=run,
                       args=(rank, cfg.federate.process_num,
                             cfg.federate.master_addr,
                             cfg.federate.master_port, client_runner))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()


def run(rank, world_size, master_addr, master_port, runner):
    # entry point for both the server process and the client processes
    logger.info("Process {} starts to run".format(rank))
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(master_port)
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    runner.setup()
    runner.run()
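

# NOTE: ``init_process_group`` above uses the default env:// rendezvous via
# the MASTER_ADDR/MASTER_PORT variables set just before it, and the NCCL
# backend requires each process to own a CUDA device (``cuda:{rank}`` in
# Runner.__init__ below). For CPU-only debugging one could, in principle,
# substitute the 'gloo' backend, though this module assumes one GPU per rank.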


class StandaloneMultiGPURunner(StandaloneRunner):
    def _set_up(self):
        if self.cfg.backend == 'torch':
            import torch
            torch.set_num_threads(1)
        assert self.cfg.federate.client_num != 0, \
            "In standalone mode, self.cfg.federate.client_num should be " \
            "non-zero. " \
            "This is usually caused by using synthetic data while not " \
            "specifying a non-zero value for client_num"

        if self.cfg.federate.client_num < self.cfg.federate.process_num:
            logger.warning('The process number is larger than the client '
                           'number; reducing process_num to client_num')
            self.cfg.federate.process_num = self.cfg.federate.client_num

    def _get_server_args(self, resource_info=None, client_resource_info=None):
        if self.server_id in self.data:
            server_data = self.data[self.server_id]
            model = get_model(self.cfg.model,
                              server_data,
                              backend=self.cfg.backend)
        else:
            # get the model according to a client's data if the server
            # does not own data
            server_data = None
            data_representative = self.data[1]
            model = get_model(self.cfg.model,
                              data_representative,
                              backend=self.cfg.backend)
        kw = {
            'shared_comm_queue': self.server2client_comm_queue,
            'id2comm': self.id2comm,
            'resource_info': resource_info,
            'client_resource_info': client_resource_info
        }
        return server_data, model, kw

    def _get_client_args(self, client_id=-1, resource_info=None):
        client_data = self.data[client_id]
        kw = {
            'shared_comm_queue': self.client2server_comm_queue,
            'resource_info': resource_info
        }
        return client_data, kw

    def run(self):
        logger.info("Multiple GPUs are starting for parallel training ...")
        # sample resource information
        if self.resource_info is not None:
            if len(self.resource_info) < self.cfg.federate.client_num + 1:
                replace = True
                logger.warning(
                    f"Because the number of provided resource information "
                    f"entries {len(self.resource_info)} is less than the "
                    f"number of participants "
                    f"{self.cfg.federate.client_num + 1}, one candidate "
                    f"might be selected multiple times.")
            else:
                replace = False
            sampled_index = np.random.choice(
                list(self.resource_info.keys()),
                size=self.cfg.federate.client_num + 1,
                replace=replace)
            server_resource_info = self.resource_info[sampled_index[0]]
            client_resource_info = [
                self.resource_info[x] for x in sampled_index[1:]
            ]
        else:
            server_resource_info = None
            client_resource_info = None
        setup_multigpu_runner(self.cfg, self.server_class, self.client_class,
                              self.unseen_clients_id, server_resource_info,
                              client_resource_info)
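

# A usage sketch (illustrative only; the builder calls below follow
# FederatedScope's common entry pattern, and the YAML file name and field
# values are hypothetical):
#
#     from federatedscope.core.configs.config import global_cfg
#     cfg = global_cfg.clone()
#     cfg.merge_from_file('fedavg_multigpu.yaml')
#     cfg.federate.process_num = 3        # 1 server + 2 client processes
#     cfg.federate.master_addr = '127.0.0.1'
#     cfg.federate.master_port = 29500
#     runner = StandaloneMultiGPURunner(data=data,
#                                       server_class=server_cls,
#                                       client_class=client_cls,
#                                       config=cfg)
#     runner.run()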


class Runner(object):
    def __init__(self, rank):
        self.rank = rank
        self.device = torch.device("cuda:{}".format(rank))

    def setup(self):
        raise NotImplementedError

    def run(self):
        raise NotImplementedError


class ServerRunner(Runner):
    def __init__(self, rank, config, server_class, receive_channel,
                 send_channels, id2comm, unseen_clients_id, resource_info,
                 client_resource_info):
        super().__init__(rank)
        self.config = config
        self.server_class = server_class
        self.receive_channel = receive_channel
        self.send_channel = send_channels
        self.id2comm = id2comm
        self.unseen_clients_id = unseen_clients_id
        self.server_id = 0
        self.resource_info = resource_info
        self.client_resource_info = client_resource_info
        self.serial_num_for_msg = 0

    def setup(self):
        self.config.defrost()
        data, modified_cfg = get_data(config=self.config, client_cfgs=None)
        self.config.merge_from_other_cfg(modified_cfg)
        self.config.freeze()
        if self.rank in data:
            self.data = data[self.rank]
            model = get_model(self.config.model,
                              self.data,
                              backend=self.config.backend)
        else:
            # get the model according to a client's data if the server
            # does not own data
            self.data = None
            model = get_model(self.config.model,
                              data[1],
                              backend=self.config.backend)
        kw = {
            'shared_comm_queue': self.send_channel,
            'id2comm': self.id2comm,
            'resource_info': self.resource_info,
            'client_resource_info': self.client_resource_info
        }

        self.server = self.server_class(
            ID=self.server_id,
            config=self.config,
            data=self.data,
            model=model,
            client_num=self.config.federate.client_num,
            total_round_num=self.config.federate.total_round_num,
            device=self.device,
            unseen_clients_id=self.unseen_clients_id,
            **kw)

        self.server.model.to(self.device)
        # keep a template state_dict so received parameters can be written
        # into correctly-shaped tensors
        self.template_para = copy.deepcopy(self.server.model.state_dict())
        if self.config.nbafl.use:
            from federatedscope.core.trainers.trainer_nbafl import \
                wrap_nbafl_server
            wrap_nbafl_server(self.server)
        logger.info('Server has been set up ...')
        _, feat_engr_wrapper_server = get_feat_engr_wrapper(self.config)
        self.server = feat_engr_wrapper_server(self.server)

    def run(self):
        logger.info("ServerRunner {} starts to run".format(self.rank))
        server_msg_cache = list()
        while True:
            if not self.receive_channel.empty():
                msg = self.receive_channel.get()
                # For the server, move the received message to a cache for
                # reordering the messages according to their timestamps
                # (heapq relies on Message objects being comparable)
                msg.serial_num = self.serial_num_for_msg
                self.serial_num_for_msg += 1
                heapq.heappush(server_msg_cache, msg)
            elif len(server_msg_cache) > 0:
                msg = heapq.heappop(server_msg_cache)
                if self.config.asyn.use and self.config.asyn.aggregator \
                        == 'time_up':
                    # When the timestamp of the received message is beyond
                    # the deadline of the current round, trigger the
                    # time-up event first and push the message back into
                    # the cache
                    if self.server.trigger_for_time_up(msg.timestamp):
                        heapq.heappush(server_msg_cache, msg)
                    else:
                        self._handle_msg(msg)
                else:
                    self._handle_msg(msg)
            else:
                if self.config.asyn.use and self.config.asyn.aggregator \
                        == 'time_up':
                    self.server.trigger_for_time_up()
                    if self.receive_channel.empty() and \
                            len(server_msg_cache) == 0:
                        break
                else:
                    if self.server.is_finish:
                        break
                    else:
                        time.sleep(0.01)

    def _handle_msg(self, msg):
        """
        To simulate the message handling process (used only in the
        standalone mode)
        """
        sender, receiver = msg.sender, msg.receiver
        download_bytes, upload_bytes = msg.count_bytes()
        if msg.msg_type == 'model_para':
            # the heavyweight model parameters travel via torch.distributed
            # rather than through the multiprocessing queue; the queue
            # message only signals that they are ready to be received
            sender_rank = self.id2comm[sender] + 1
            tmp_model_para = copy.deepcopy(self.template_para)
            recv_model_para(tmp_model_para, sender_rank)
            msg.content = (msg.content[0], tmp_model_para)
        if not isinstance(receiver, list):
            receiver = [receiver]
        for each_receiver in receiver:
            if each_receiver == 0:
                self.server.msg_handlers[msg.msg_type](msg)
                self.server._monitor.track_download_bytes(download_bytes)
            else:
                # should not happen: the server process only handles
                # messages addressed to the server (ID 0)
                logger.warning('Server received a message with a wrong '
                               'receiver')
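

# How a client-side sender pairs with ``_handle_msg`` above (a sketch; it
# assumes FederatedScope's ``Message`` class and the hypothetical
# ``send_model_para`` helper sketched near the top of this file):
#
#     comm_queue.put(Message(msg_type='model_para', sender=client_id,
#                            receiver=[0], content=(sample_size, {})))
#     send_model_para(model.state_dict(), dst_rank=0)
#
# i.e., the queue carries a lightweight placeholder message, while the
# actual tensors travel over torch.distributed.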


class ClientRunner(Runner):
    def __init__(self, rank, client_ids, config, client_class,
                 unseen_clients_id, receive_channel, send_channel,
                 client_resource_info):
        super().__init__(rank)
        self.client_ids = client_ids
        self.config = config
        self.client_class = client_class
        self.unseen_clients_id = unseen_clients_id
        self.base_client_id = client_ids[0]
        self.receive_channel = receive_channel
        self.client2server_comm_queue = send_channel
        self.client_group = dict()
        self.client_resource_info = client_resource_info
        self.is_finish = False

    def setup(self):
        self.config.defrost()
        self.data, modified_cfg = get_data(config=self.config,
                                           client_cfgs=None)
        self.config.merge_from_other_cfg(modified_cfg)
        self.config.freeze()
        # when share_local_model is enabled, all clients in this process
        # share a single model instance to save memory
        self.shared_model = get_model(
            self.config.model,
            self.data[self.base_client_id],
            backend=self.config.backend
        ) if self.config.federate.share_local_model else None

        server_id = 0

        for client_id in self.client_ids:
            client_data = self.data[client_id]
            kw = {
                'shared_comm_queue': self.client2server_comm_queue,
                'resource_info': self.client_resource_info[client_id - 1]
                if self.client_resource_info is not None else None
            }
            # client-specific configs are not supported in this runner
            # (note the ``client_cfgs=None`` above), so every client gets
            # a clone of the shared config
            client_specific_config = self.config.clone()
            client = self.client_class(
                ID=client_id,
                server_id=server_id,
                config=client_specific_config,
                data=client_data,
                model=self.shared_model
                or get_model(client_specific_config.model,
                             client_data,
                             backend=self.config.backend),
                device=self.device,
                is_unseen_client=client_id in self.unseen_clients_id,
                **kw)
            client.model.to(self.device)
            logger.info(f'Client {client_id} has been set up ...')
            self.client_group[client_id] = client
        # keep a template state_dict so parameters received from the
        # server can be written into correctly-shaped tensors
        self.template_para = copy.deepcopy(
            self.client_group[self.base_client_id].model.state_dict())

    def run(self):
        logger.info("ClientRunner {} starts to run".format(self.rank))
        for _, client in self.client_group.items():
            client.join_in()
        while True:
            if not self.receive_channel.empty():
                msg = self.receive_channel.get()
                self._handle_msg(msg)
            elif self.is_finish:
                break
            else:
                # avoid busy-waiting when there is nothing to handle
                time.sleep(0.01)

    def _handle_msg(self, msg):
        _, receiver = msg.sender, msg.receiver
        msg_type = msg.msg_type
        if msg_type == 'model_para' or msg_type == 'evaluate':
            # the model parameters come from the server (rank 0) via
            # torch.distributed; the queue message is only a placeholder
            recv_model_para(self.template_para, 0)
            msg.content = self.template_para
        download_bytes, upload_bytes = msg.count_bytes()
        if not isinstance(receiver, list):
            receiver = [receiver]
        for each_receiver in receiver:
            if each_receiver in self.client_ids:
                self.client_group[each_receiver].msg_handlers[msg.msg_type](
                    msg)
                self.client_group[each_receiver]._monitor \
                    .track_download_bytes(download_bytes)
        if msg.msg_type == 'finish':
            self.is_finish = True
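

# Minimal two-rank smoke test for the parameter send/receive path (a
# sketch, not part of FederatedScope; it assumes two visible GPUs and the
# hypothetical ``send_model_para`` helper above):
#
#     def _smoke(rank):
#         os.environ['MASTER_ADDR'] = '127.0.0.1'
#         os.environ['MASTER_PORT'] = '29501'
#         dist.init_process_group('nccl', rank=rank, world_size=2)
#         para = {'w': torch.full((4, ), float(rank), device=f'cuda:{rank}')}
#         if rank == 0:
#             send_model_para(para, dst_rank=1)
#         else:
#             recv_model_para(para, src_rank=0)
#             assert para['w'].sum().item() == 0.0  # rank 0 sent zeros
#         dist.destroy_process_group()
#
#     if __name__ == '__main__':
#         mp.spawn(_smoke, nprocs=2)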