FS-TFP/federatedscope/core/communication.py

213 lines
7.9 KiB
Python

import grpc
from concurrent import futures
import logging
import torch.distributed as dist
from collections import deque
from federatedscope.core.proto import gRPC_comm_manager_pb2, \
gRPC_comm_manager_pb2_grpc
from federatedscope.core.gRPC_server import gRPCComServeFunc
from federatedscope.core.message import Message
# Module-wide logger. Its own level is set to INFO here instead of being
# inherited from the parent logger.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)
class StandaloneCommManager(object):
    """
    The communicator used for standalone mode.

    All participants share a single ``comm_queue``: ``send`` appends to it
    and nothing is received through this object.

    Arguments:
        comm_queue: the shared queue outgoing messages are appended to
            (anything with an ``append`` method, e.g. a ``deque``).
        monitor: optional object used to track communication-related
            metrics; stored here but not used by this class itself.
    """
    def __init__(self, comm_queue, monitor=None):
        self.comm_queue = comm_queue
        # Maps neighbor ID -> address (the address may be None here)
        self.neighbors = dict()
        self.monitor = monitor  # used to track the communication related
        # metrics

    def receive(self):
        # we don't need receive() in standalone
        pass

    def add_neighbors(self, neighbor_id, address=None):
        """Register (or overwrite) the address of ``neighbor_id``."""
        self.neighbors[neighbor_id] = address

    def get_neighbors(self, neighbor_id=None):
        """
        Look up neighbor addresses.

        Arguments:
            neighbor_id: ``None`` for the whole ``{id: address}`` mapping,
                a list of IDs for a ``{id: address}`` dict restricted to
                those IDs (an empty list yields an empty dict), or a single
                ID for that neighbor's address.

        Returns:
            The address or mapping described above. With ``None`` the
            internal mapping itself is returned, not a copy.

        Raises:
            KeyError: if a requested ID has not been added.
        """
        if neighbor_id is None:
            # Get all neighbors
            return self.neighbors
        if isinstance(neighbor_id, list):
            return {
                each_neighbor: self.get_neighbors(each_neighbor)
                for each_neighbor in neighbor_id
            }
        # Checked against None (not truthiness) so that a falsy ID such as
        # 0 is looked up instead of being treated as "all neighbors".
        return self.neighbors[neighbor_id]

    def send(self, message):
        # All the workers share one comm_queue
        self.comm_queue.append(message)
class StandaloneDDPCommManager(StandaloneCommManager):
    """
    The communicator used for standalone mode with multigpu.

    Message metadata goes through Python queues, while model parameters are
    moved separately with point-to-point ``torch.distributed.send`` calls.

    Arguments:
        comm_queue: a ``deque`` (appended to) or a queue-like object with
            ``put``; when ``id2comm`` is given, a sequence of queues that
            each support ``put``.
        monitor: object exposing ``track_upload_bytes``; may be ``None``, in
            which case byte tracking is skipped.
        id2comm: optional mapping from receiver ID to the index of the queue
            in ``comm_queue`` serving it. ``None`` means this side sends to
            the server (parameters go to rank 0).
    """
    def __init__(self, comm_queue, monitor=None, id2comm=None):
        super().__init__(comm_queue, monitor)
        self.id2comm = id2comm
        # presumably one GPU per rank — TODO confirm
        self.device = "cuda:{}".format(dist.get_rank())

    def _enqueue(self, target, message):
        """Append to a ``deque``; otherwise ``put`` into a queue-like."""
        if isinstance(target, deque):
            target.append(message)
        else:
            target.put(message)

    def _send_model_para(self, model_para, dst_rank):
        """Send every tensor of ``model_para`` (dict order) to ``dst_rank``."""
        for v in model_para.values():
            t = v.to(self.device)
            dist.send(tensor=t, dst=dst_rank)

    def _serves_receiver(self, receivers, idx):
        """True if queue ``idx`` serves any known neighbor in ``receivers``."""
        return any(each_receiver in self.neighbors
                   and self.id2comm[each_receiver] == idx
                   for each_receiver in receivers)

    def send(self, message):
        """
        Deliver ``message``, sending parameter tensors via ``dist.send``.

        ``message.content`` is modified in place: for ``model_para``
        messages (and, when ``id2comm`` is set, ``evaluate`` messages) the
        parameter dict is stripped out before the message is queued.
        """
        is_model_para = message.msg_type == 'model_para'
        is_evaluate = message.msg_type == 'evaluate'
        if self.id2comm is None:
            # client to server
            if is_model_para:
                # assumes content is (sample_size, model_para) — TODO confirm
                model_para = message.content[1]
                message.content = (message.content[0], {})
                self._enqueue(self.comm_queue, message)
                self._send_model_para(model_para, 0)
            else:
                self._enqueue(self.comm_queue, message)
        else:
            receiver = message.receiver
            if not isinstance(receiver, list):
                receiver = [receiver]
            carries_params = is_model_para or is_evaluate
            if carries_params:
                model_para = message.content
                message.content = {}
            for idx, each_comm in enumerate(self.comm_queue):
                if not self._serves_receiver(receiver, idx):
                    continue
                # Queue the stripped message first, then push the tensors
                # to the rank behind queue ``idx`` (rank idx + 1).
                each_comm.put(message)
                if carries_params:
                    self._send_model_para(model_para, idx + 1)
        if self.monitor is not None:
            # NOTE(review): bytes are counted after the parameters were
            # stripped from ``message.content``, so tensors sent through
            # ``dist.send`` are presumably not included — confirm intent.
            _, upload_bytes = message.count_bytes()
            self.monitor.track_upload_bytes(upload_bytes)
class gRPCCommManager(object):
    """
    The implementation of gRPCCommManager is referred to the tutorial on
    https://grpc.io/docs/languages/python/

    Each instance starts a gRPC server on ``host:port`` for incoming
    messages and opens a short-lived channel for every outgoing message.

    Arguments:
        host: address the local server binds to.
        port: port the local server binds to.
        client_num: number of worker threads of the local server.
        cfg: config object providing ``grpc_max_send_message_length``,
            ``grpc_max_receive_message_length``, ``grpc_enable_http_proxy``
            and ``grpc_compression``. Required despite the ``None`` default.

    Raises:
        ValueError: if ``cfg`` is ``None``.
    """
    def __init__(self, host='0.0.0.0', port='50050', client_num=2, cfg=None):
        if cfg is None:
            raise ValueError("gRPCCommManager requires a cfg providing the "
                             "grpc_* settings")
        self.host = host
        self.port = port
        options = [
            ("grpc.max_send_message_length", cfg.grpc_max_send_message_length),
            ("grpc.max_receive_message_length",
             cfg.grpc_max_receive_message_length),
            ("grpc.enable_http_proxy", cfg.grpc_enable_http_proxy),
        ]
        compression = cfg.grpc_compression.lower()
        if compression == 'deflate':
            self.comp_method = grpc.Compression.Deflate
        elif compression == 'gzip':
            self.comp_method = grpc.Compression.Gzip
        else:
            # Any other value disables compression
            self.comp_method = grpc.Compression.NoCompression
        # comp_method must be set before serve(), which reads it
        self.server_funcs = gRPCComServeFunc()
        self.grpc_server = self.serve(max_workers=client_num,
                                      host=host,
                                      port=port,
                                      options=options)
        # Maps neighbor ID -> "host:port" string
        self.neighbors = dict()
        self.monitor = None  # used to track the communication related metrics

    def serve(self, max_workers, host, port, options):
        """
        This function is referred to
        https://grpc.io/docs/languages/python/basics/#starting-the-server

        Start an insecure gRPC server on ``host:port`` and return it.
        """
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=max_workers),
            compression=self.comp_method,
            options=options)
        gRPC_comm_manager_pb2_grpc.add_gRPCComServeFuncServicer_to_server(
            self.server_funcs, server)
        server.add_insecure_port("{}:{}".format(host, port))
        server.start()
        return server

    def add_neighbors(self, neighbor_id, address):
        """
        Register the address of ``neighbor_id``.

        Arguments:
            address: either a dict with ``'host'`` and ``'port'`` keys or an
                already formatted ``"host:port"`` string.

        Raises:
            TypeError: for any other address type.
        """
        if isinstance(address, dict):
            self.neighbors[neighbor_id] = '{}:{}'.format(
                address['host'], address['port'])
        elif isinstance(address, str):
            self.neighbors[neighbor_id] = address
        else:
            raise TypeError(f"The type of address ({type(address)}) is not "
                            "supported yet")

    def get_neighbors(self, neighbor_id=None):
        """
        Look up neighbor addresses.

        Arguments:
            neighbor_id: ``None`` for the whole ``{id: address}`` mapping,
                a list of IDs for a dict restricted to those IDs (an empty
                list yields an empty dict), or a single ID for its address.

        Raises:
            KeyError: if a requested ID has not been added.
        """
        if neighbor_id is None:
            # Get all neighbors
            return self.neighbors
        if isinstance(neighbor_id, list):
            return {
                each_neighbor: self.get_neighbors(each_neighbor)
                for each_neighbor in neighbor_id
            }
        # Checked against None (not truthiness) so that a falsy ID such as
        # 0 is looked up instead of being treated as "all neighbors".
        return self.neighbors[neighbor_id]

    def _send(self, receiver_address, message):
        """
        Send ``message`` to ``receiver_address`` over a fresh channel.

        RPC failures are logged as warnings and otherwise ignored, which
        lets ``send`` continue with the remaining receivers. The channel is
        always closed, even if the call raises something unexpected.
        """
        request = message.transform(to_list=True)
        # This part is referred to
        # https://grpc.io/docs/languages/python/basics/#creating-a-stub
        channel = grpc.insecure_channel(receiver_address,
                                        compression=self.comp_method,
                                        options=(('grpc.enable_http_proxy',
                                                  0), ))
        try:
            stub = gRPC_comm_manager_pb2_grpc.gRPCComServeFuncStub(channel)
            try:
                stub.sendMessage(request)
            except grpc.RpcError as error:
                # Public base class of the previously caught private
                # grpc._channel._InactiveRpcError.
                logger.warning(error)
        finally:
            channel.close()

    def send(self, message):
        """
        Send ``message`` to ``message.receiver`` (an ID or a list of IDs).

        Receivers that are not registered neighbors are skipped silently.
        A ``None`` receiver broadcasts to every registered neighbor.
        """
        receiver = message.receiver
        if receiver is not None:
            if not isinstance(receiver, list):
                receiver = [receiver]
            for each_receiver in receiver:
                if each_receiver in self.neighbors:
                    receiver_address = self.neighbors[each_receiver]
                    self._send(receiver_address, message)
        else:
            for each_receiver in self.neighbors:
                receiver_address = self.neighbors[each_receiver]
                self._send(receiver_address, message)

    def receive(self):
        """Block on the server side for the next message and parse it."""
        received_msg = self.server_funcs.receive()
        message = Message()
        message.parse(received_msg.msg)
        return message