# FS-TFP/federatedscope/core/parallel/parallel_runner.py
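"""Standalone multi-GPU runner for FederatedScope.

Spawns one server process (rank 0) and ``process_num - 1`` client
processes via ``torch.multiprocessing``. Small control messages travel
through ``mp.Queue`` objects, while model parameters are exchanged
out-of-band with ``torch.distributed`` point-to-point ops over the NCCL
backend (this module only issues the ``recv`` side; the matching sends
are assumed to be issued by the workers' communicators).
"""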


import logging
import time
import os
import copy
import heapq
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from federatedscope.core.fed_runner import StandaloneRunner
from federatedscope.core.auxiliaries.model_builder import get_model
from federatedscope.core.auxiliaries.feat_engr_builder import \
get_feat_engr_wrapper
from federatedscope.core.auxiliaries.data_builder import get_data
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def recv_model_para(model_para, src_rank):
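    """Receive every tensor of ``model_para`` (a state_dict) in-place
    from process ``src_rank`` via blocking point-to-point recv. Assumes
    the sender transmits the tensors in the same (insertion) order."""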
for v in model_para.values():
dist.recv(tensor=v, src=src_rank)
def setup_multigpu_runner(cfg, server_class, client_class, unseen_clients_id,
server_resource_info, client_resource_info):
processes = []
mp.set_start_method("spawn")
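    # Communication topology: all client processes share a single
    # upstream queue (client -> server), while each of the
    # process_num - 1 client processes gets its own downstream queue
    # (server -> client).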
# init parameter
client2server_queue = mp.Queue()
server2client_queues = [
mp.Queue() for _ in range(1, cfg.federate.process_num)
]
id2comm = dict()
clients_id_list = []
client_num_per_process = \
cfg.federate.client_num // (cfg.federate.process_num - 1)
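    # Partition clients into contiguous ID blocks, one block per client
    # process; the last process also takes the remainder of
    # client_num % (process_num - 1).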
for process_id in range(1, cfg.federate.process_num):
client_ids_start = (process_id - 1) * client_num_per_process + 1
client_ids_end = client_ids_start + client_num_per_process \
if process_id != cfg.federate.process_num - 1 \
else cfg.federate.client_num + 1
clients_id_list.append(range(client_ids_start, client_ids_end))
for client_id in range(client_ids_start, client_ids_end):
id2comm[client_id] = process_id - 1
# setup server process
server_rank = 0
server_process = mp.Process(
target=run,
args=(server_rank, cfg.federate.process_num, cfg.federate.master_addr,
cfg.federate.master_port,
ServerRunner(rank=server_rank,
config=cfg,
server_class=server_class,
receive_channel=client2server_queue,
send_channels=server2client_queues,
id2comm=id2comm,
unseen_clients_id=unseen_clients_id,
resource_info=server_resource_info,
client_resource_info=client_resource_info)))
server_process.start()
processes.append(server_process)
# setup client process
for rank in range(1, cfg.federate.process_num):
client_runner = ClientRunner(
rank=rank,
client_ids=clients_id_list[rank - 1],
config=cfg,
client_class=client_class,
unseen_clients_id=unseen_clients_id,
receive_channel=server2client_queues[rank - 1],
send_channel=client2server_queue,
client_resource_info=client_resource_info)
p = mp.Process(target=run,
args=(rank, cfg.federate.process_num,
cfg.federate.master_addr,
cfg.federate.master_port, client_runner))
p.start()
processes.append(p)
for p in processes:
p.join()
def run(rank, world_size, master_addr, master_port, runner):
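    """Entry point of every spawned process: join the NCCL process group
    (rank 0 is the server), then set up and run the given runner."""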
logger.info("Process {} start to run".format(rank))
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = str(master_port)
dist.init_process_group('nccl', rank=rank, world_size=world_size)
    runner.setup()
    runner.run()
class StandaloneMultiGPURunner(StandaloneRunner):
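    """Standalone runner that parallelizes simulation over GPUs: rank 0
    hosts the server and every other rank hosts a group of clients (see
    ``setup_multigpu_runner``)."""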
def _set_up(self):
if self.cfg.backend == 'torch':
import torch
torch.set_num_threads(1)
        assert self.cfg.federate.client_num != 0, \
            "In standalone mode, self.cfg.federate.client_num should be " \
            "non-zero. This is usually caused by using synthetic data " \
            "without specifying a non-zero client_num."
if self.cfg.federate.client_num < self.cfg.federate.process_num:
            logger.warning('The process number is larger than the client '
                           'number; reducing process_num to client_num')
            self.cfg.federate.process_num = self.cfg.federate.client_num
def _get_server_args(self, resource_info=None, client_resource_info=None):
if self.server_id in self.data:
server_data = self.data[self.server_id]
model = get_model(self.cfg.model,
server_data,
backend=self.cfg.backend)
        else:
            # the server does not own data; build the model from a
            # representative client's data instead
            server_data = None
            data_representative = self.data[1]
            model = get_model(self.cfg.model,
                              data_representative,
                              backend=self.cfg.backend)
kw = {
'shared_comm_queue': self.server2client_comm_queue,
'id2comm': self.id2comm,
'resource_info': resource_info,
'client_resource_info': client_resource_info
}
return server_data, model, kw
def _get_client_args(self, client_id=-1, resource_info=None):
client_data = self.data[client_id]
kw = {
'shared_comm_queue': self.client2server_comm_queue,
'resource_info': resource_info
}
return client_data, kw
def run(self):
logger.info("Multi-GPU are starting for parallel training ...")
# sample resource information
if self.resource_info is not None:
if len(self.resource_info) < self.cfg.federate.client_num + 1:
replace = True
                logger.warning(
                    f"Because the number of provided resource info entries "
                    f"({len(self.resource_info)}) is less than the number of "
                    f"participants ({self.cfg.federate.client_num + 1}), one "
                    f"candidate might be selected multiple times.")
else:
replace = False
sampled_index = np.random.choice(
list(self.resource_info.keys()),
size=self.cfg.federate.client_num + 1,
replace=replace)
server_resource_info = self.resource_info[sampled_index[0]]
client_resource_info = [
self.resource_info[x] for x in sampled_index[1:]
]
else:
server_resource_info = None
client_resource_info = None
setup_multigpu_runner(self.cfg, self.server_class, self.client_class,
self.unseen_clients_id, server_resource_info,
client_resource_info)
class Runner(object):
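    """Base class for per-process runners; pins the process to
    ``cuda:{rank}``."""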
def __init__(self, rank):
self.rank = rank
self.device = torch.device("cuda:{}".format(rank))
def setup(self):
raise NotImplementedError
def run(self):
raise NotImplementedError
class ServerRunner(Runner):
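    """Hosts the server on rank 0. ``setup`` builds the data, model and
    server instance; ``run`` drains the client->server queue, reordering
    messages by timestamp before handling them."""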
def __init__(self, rank, config, server_class, receive_channel,
send_channels, id2comm, unseen_clients_id, resource_info,
client_resource_info):
super().__init__(rank)
self.config = config
self.server_class = server_class
self.receive_channel = receive_channel
self.send_channel = send_channels
self.id2comm = id2comm
self.unseen_clients_id = unseen_clients_id
self.server_id = 0
self.resource_info = resource_info
self.client_resource_info = client_resource_info
self.serial_num_for_msg = 0
def setup(self):
self.config.defrost()
data, modified_cfg = get_data(config=self.config, client_cfgs=None)
self.config.merge_from_other_cfg(modified_cfg)
self.config.freeze()
        if self.rank in data:
            self.data = data[self.rank]
model = get_model(self.config.model,
self.data,
backend=self.config.backend)
else:
self.data = None
model = get_model(self.config.model,
data[1],
backend=self.config.backend)
kw = {
'shared_comm_queue': self.send_channel,
'id2comm': self.id2comm,
'resource_info': self.resource_info,
'client_resource_info': self.client_resource_info
}
self.server = self.server_class(
ID=self.server_id,
config=self.config,
data=self.data,
model=model,
client_num=self.config.federate.client_num,
            total_round_num=self.config.federate.total_round_num,
device=self.device,
unseen_clients_id=self.unseen_clients_id,
**kw)
self.server.model.to(self.device)
self.template_para = copy.deepcopy(self.server.model.state_dict())
if self.config.nbafl.use:
from federatedscope.core.trainers.trainer_nbafl import \
wrap_nbafl_server
wrap_nbafl_server(self.server)
logger.info('Server has been set up ... ')
_, feat_engr_wrapper_server = get_feat_engr_wrapper(self.config)
self.server = feat_engr_wrapper_server(self.server)
def run(self):
logger.info("ServerRunner {} start to run".format(self.rank))
server_msg_cache = list()
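        # heapq keeps the cache ordered by Message comparison; in
        # FederatedScope, messages order by timestamp and break ties
        # with the serial number assigned below, so they are handled
        # in timestamp order.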
while True:
if not self.receive_channel.empty():
msg = self.receive_channel.get()
# For the server, move the received message to a
# cache for reordering the messages according to
# the timestamps
msg.serial_num = self.serial_num_for_msg
self.serial_num_for_msg += 1
heapq.heappush(server_msg_cache, msg)
elif len(server_msg_cache) > 0:
msg = heapq.heappop(server_msg_cache)
if self.config.asyn.use and self.config.asyn.aggregator \
== 'time_up':
                    # When the timestamp of the received message is beyond
                    # the deadline of the current round, trigger the
                    # time-up event first and push the message back into
                    # the cache
if self.server.trigger_for_time_up(msg.timestamp):
heapq.heappush(server_msg_cache, msg)
else:
self._handle_msg(msg)
else:
self._handle_msg(msg)
else:
if self.config.asyn.use and self.config.asyn.aggregator \
== 'time_up':
self.server.trigger_for_time_up()
                    if self.receive_channel.empty() and \
                            len(server_msg_cache) == 0:
break
else:
if self.server.is_finish:
break
else:
time.sleep(0.01)
def _handle_msg(self, msg):
"""
To simulate the message handling process (used only for the
standalone mode)
"""
sender, receiver = msg.sender, msg.receiver
download_bytes, upload_bytes = msg.count_bytes()
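        # The queue item for 'model_para' carries only metadata; the
        # actual tensors arrive out-of-band via torch.distributed from
        # the sender's rank, so receive them into a fresh copy of the
        # parameter template and re-attach them to the message.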
if msg.msg_type == 'model_para':
sender_rank = self.id2comm[sender] + 1
tmp_model_para = copy.deepcopy(self.template_para)
            recv_model_para(tmp_model_para, sender_rank)
msg.content = (msg.content[0], tmp_model_para)
if not isinstance(receiver, list):
receiver = [receiver]
for each_receiver in receiver:
if each_receiver == 0:
self.server.msg_handlers[msg.msg_type](msg)
self.server._monitor.track_download_bytes(download_bytes)
else:
# should not go here
logger.warning('server received a wrong message')
class ClientRunner(Runner):
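    """Hosts a group of clients on one GPU process. ``setup``
    instantiates every client in ``client_ids`` (optionally sharing one
    model object); ``run`` joins them in and dispatches incoming server
    messages."""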
def __init__(self, rank, client_ids, config, client_class,
unseen_clients_id, receive_channel, send_channel,
client_resource_info):
super().__init__(rank)
self.client_ids = client_ids
self.config = config
self.client_class = client_class
self.unseen_clients_id = unseen_clients_id
self.base_client_id = client_ids[0]
self.receive_channel = receive_channel
self.client2server_comm_queue = send_channel
self.client_group = dict()
self.client_resource_info = client_resource_info
self.is_finish = False
def setup(self):
self.config.defrost()
self.data, modified_cfg = get_data(config=self.config,
client_cfgs=None)
self.config.merge_from_other_cfg(modified_cfg)
self.config.freeze()
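        # With federate.share_local_model, every client in this process
        # reuses a single model object instead of holding its own
        # replica, which saves GPU memory.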
self.shared_model = get_model(
self.config.model,
self.data[self.base_client_id],
backend=self.config.backend
) if self.config.federate.share_local_model else None
server_id = 0
for client_id in self.client_ids:
client_data = self.data[client_id]
kw = {
'shared_comm_queue': self.client2server_comm_queue,
                'resource_info': self.client_resource_info[client_id - 1]
                if self.client_resource_info is not None else None
}
            # Note: per-client configs (client_cfgs) are not propagated
            # to the subprocesses (get_data above is called with
            # client_cfgs=None), so every client in this process runs
            # with a clone of the global config.
            client_specific_config = self.config.clone()
client = self.client_class(
ID=client_id,
server_id=server_id,
config=client_specific_config,
data=client_data,
model=self.shared_model
or get_model(client_specific_config.model,
client_data,
backend=self.config.backend),
device=self.device,
is_unseen_client=client_id in self.unseen_clients_id,
**kw)
client.model.to(self.device)
logger.info(f'Client {client_id} has been set up ... ')
self.client_group[client_id] = client
self.template_para = copy.deepcopy(
self.client_group[self.base_client_id].model.state_dict())
def run(self):
logger.info("ClientRunner {} start to run".format(self.rank))
for _, client in self.client_group.items():
client.join_in()
while True:
if not self.receive_channel.empty():
msg = self.receive_channel.get()
self._handle_msg(msg)
elif self.is_finish:
break
def _handle_msg(self, msg):
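        """Dispatch a server message to the addressed clients in this
        process. For 'model_para'/'evaluate' messages, the parameters
        are first received out-of-band from the server (rank 0) into the
        shared template state_dict."""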
_, receiver = msg.sender, msg.receiver
msg_type = msg.msg_type
if msg_type == 'model_para' or msg_type == 'evaluate':
# recv from server
            recv_model_para(self.template_para, 0)
msg.content = self.template_para
download_bytes, upload_bytes = msg.count_bytes()
if not isinstance(receiver, list):
receiver = [receiver]
for each_receiver in receiver:
if each_receiver in self.client_ids:
self.client_group[each_receiver].msg_handlers[msg.msg_type](
msg)
self.client_group[each_receiver]._monitor.track_download_bytes(
download_bytes)
if msg.msg_type == 'finish':
self.is_finish = True