import logging
from collections import deque
from concurrent import futures

import grpc
import torch.distributed as dist

from federatedscope.core.proto import gRPC_comm_manager_pb2, \
    gRPC_comm_manager_pb2_grpc
from federatedscope.core.gRPC_server import gRPCComServeFunc
from federatedscope.core.message import Message

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def _lookup_neighbors(neighbors, neighbor_id=None):
    """
    Resolve ``neighbor_id`` against the ``neighbors`` mapping.

    Args:
        neighbors: dict mapping neighbor id -> address.
        neighbor_id: ``None`` (or an empty list) to get the whole mapping, \
            a list of ids to get a ``{id: address}`` dict, or a single id \
            to get that neighbor's address.

    Returns:
        The whole mapping, a dict restricted to the requested ids, or a \
        single address.

    Raises:
        KeyError: if a requested id is not a known neighbor.
    """
    # Only "nothing requested" means "all neighbors". A plain truthiness
    # test would also swallow the valid id ``0``.
    if neighbor_id is None or (isinstance(neighbor_id, list)
                               and not neighbor_id):
        return neighbors
    if isinstance(neighbor_id, list):
        return {each: neighbors[each] for each in neighbor_id}
    return neighbors[neighbor_id]


class StandaloneCommManager(object):
    """
    The communicator used for standalone mode
    """
    def __init__(self, comm_queue, monitor=None):
        """
        Args:
            comm_queue: the queue shared by all workers; messages are \
                appended to it by ``send``.
            monitor: optional object used to track the communication \
                related metrics.
        """
        self.comm_queue = comm_queue
        self.neighbors = dict()
        self.monitor = monitor

    def receive(self):
        """
        No-op: receive() is not needed in standalone mode. Returns ``None``.
        """

    def add_neighbors(self, neighbor_id, address=None):
        """
        Register (or overwrite) the address of ``neighbor_id``.
        """
        self.neighbors[neighbor_id] = address

    def get_neighbors(self, neighbor_id=None):
        """
        Return the address of one neighbor, a ``{id: address}`` dict for a \
        list of ids, or all neighbors when ``neighbor_id`` is ``None`` (or \
        an empty list).

        Raises:
            KeyError: if a requested id is not a known neighbor.
        """
        return _lookup_neighbors(self.neighbors, neighbor_id)

    def send(self, message):
        """
        Append ``message`` to the comm_queue shared by all the workers.
        """
        self.comm_queue.append(message)


class StandaloneDDPCommManager(StandaloneCommManager):
    """
    The communicator used for standalone mode with multigpu
    """
    def __init__(self, comm_queue, monitor=None, id2comm=None):
        """
        Args:
            comm_queue: a single queue (``id2comm`` is ``None``) or an \
                iterable of queues indexed by the values of ``id2comm``.
            monitor: optional object used to track the communication \
                related metrics.
            id2comm: optional mapping from receiver id to the index of its \
                queue in ``comm_queue``. ``send`` uses the client-to-server \
                path when this is ``None``.
        """
        super().__init__(comm_queue, monitor)
        self.id2comm = id2comm
        self.device = "cuda:{}".format(dist.get_rank())

    def _send_model_para(self, model_para, dst_rank):
        """
        Move every value of ``model_para`` to ``self.device`` and send it \
        to ``dst_rank`` with ``dist.send``, in the dict's iteration order.
        """
        for tensor in model_para.values():
            dist.send(tensor=tensor.to(self.device), dst=dst_rank)

    def _enqueue(self, message):
        """
        Put ``message`` on the single comm_queue, using ``append`` for a \
        ``deque`` and ``put`` for anything else.
        """
        if isinstance(self.comm_queue, deque):
            self.comm_queue.append(message)
        else:
            self.comm_queue.put(message)

    def send(self, message):
        """
        Send ``message`` through the queue(s). For the message types handled \
        below, the payload is stripped from the queued message and sent \
        separately with ``dist.send``.

        Args:
            message: object exposing ``msg_type``, ``content``, \
                ``receiver`` and ``count_bytes()``. Its ``content`` is \
                modified in place when the payload is stripped.
        """
        is_model_para = message.msg_type == 'model_para'
        is_evaluate = message.msg_type == 'evaluate'
        if self.id2comm is None:
            # client to server
            if is_model_para:
                # Keep the first element of the content in the queued
                # message; the second element is sent to rank 0.
                model_para = message.content[1]
                message.content = (message.content[0], {})
                self._enqueue(message)
                self._send_model_para(model_para, 0)
            else:
                self._enqueue(message)
        else:
            receiver = message.receiver
            if not isinstance(receiver, list):
                receiver = [receiver]
            carries_para = is_model_para or is_evaluate
            model_para = None
            if carries_para:
                model_para = message.content
                message.content = {}
            for idx, each_comm in enumerate(self.comm_queue):
                # A queue gets the message at most once, as soon as any
                # known receiver is mapped to it.
                targeted = any(each_receiver in self.neighbors
                               and self.id2comm[each_receiver] == idx
                               for each_receiver in receiver)
                if not targeted:
                    continue
                each_comm.put(message)
                if carries_para:
                    # The payload for queue ``idx`` goes to rank ``idx + 1``.
                    self._send_model_para(model_para, idx + 1)
        # ``monitor`` defaults to None, so only track when one is attached.
        if self.monitor is not None:
            _, upload_bytes = message.count_bytes()
            self.monitor.track_upload_bytes(upload_bytes)


class gRPCCommManager(object):
    """
        The implementation of gRPCCommManager is referred to the tutorial on
        https://grpc.io/docs/languages/python/
    """
    def __init__(self, host='0.0.0.0', port='50050', client_num=2, cfg=None):
        """
        Build the channel options from ``cfg``, pick the compression method \
        and start the gRPC server.

        Args:
            host: host the server binds to.
            port: port the server binds to.
            client_num: number of worker threads of the server.
            cfg: config object providing ``grpc_max_send_message_length``, \
                ``grpc_max_receive_message_length``, \
                ``grpc_enable_http_proxy`` and ``grpc_compression``. It is \
                dereferenced unconditionally, so it must be provided even \
                though the default is ``None``.
        """
        self.host = host
        self.port = port
        options = [
            ("grpc.max_send_message_length", cfg.grpc_max_send_message_length),
            ("grpc.max_receive_message_length",
             cfg.grpc_max_receive_message_length),
            ("grpc.enable_http_proxy", cfg.grpc_enable_http_proxy),
        ]

        # Any value other than 'deflate' / 'gzip' (case-insensitive) means
        # no compression.
        supported_compression = {
            'deflate': grpc.Compression.Deflate,
            'gzip': grpc.Compression.Gzip,
        }
        self.comp_method = supported_compression.get(
            cfg.grpc_compression.lower(), grpc.Compression.NoCompression)

        self.server_funcs = gRPCComServeFunc()
        self.grpc_server = self.serve(max_workers=client_num,
                                      host=host,
                                      port=port,
                                      options=options)
        self.neighbors = dict()
        self.monitor = None  # used to track the communication related metrics

    def serve(self, max_workers, host, port, options):
        """
        Create a gRPC server, register ``self.server_funcs`` on it, bind it \
        to an insecure port at ``host:port`` and start it.

        This function is referred to
        https://grpc.io/docs/languages/python/basics/#starting-the-server

        Returns:
            The started ``grpc`` server.
        """
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=max_workers),
            compression=self.comp_method,
            options=options)
        gRPC_comm_manager_pb2_grpc.add_gRPCComServeFuncServicer_to_server(
            self.server_funcs, server)
        server.add_insecure_port("{}:{}".format(host, port))
        server.start()

        return server

    def add_neighbors(self, neighbor_id, address):
        """
        Register (or overwrite) the address of ``neighbor_id``.

        Args:
            neighbor_id: id of the neighbor.
            address: either a dict with ``'host'`` and ``'port'`` keys or a \
                ready-made ``'host:port'`` string.

        Raises:
            TypeError: if ``address`` is neither a dict nor a str.
        """
        if isinstance(address, dict):
            self.neighbors[neighbor_id] = '{}:{}'.format(
                address['host'], address['port'])
        elif isinstance(address, str):
            self.neighbors[neighbor_id] = address
        else:
            raise TypeError(f"The type of address ({type(address)}) is not "
                            "supported yet")

    def get_neighbors(self, neighbor_id=None):
        """
        Return the address of one neighbor, a ``{id: address}`` dict for a \
        list of ids, or all neighbors when ``neighbor_id`` is ``None`` (or \
        an empty list).

        Raises:
            KeyError: if a requested id is not a known neighbor.
        """
        return _lookup_neighbors(self.neighbors, neighbor_id)

    def _send(self, receiver_address, message):
        """
        Open an insecure channel to ``receiver_address``, send ``message`` \
        through ``sendMessage`` and close the channel.

        An RPC failure is logged as a warning and not re-raised. The stub \
        creation is referred to
        https://grpc.io/docs/languages/python/basics/#creating-a-stub
        """
        # The context manager closes the channel even when building or
        # sending the request raises.
        with grpc.insecure_channel(receiver_address,
                                   compression=self.comp_method,
                                   options=(('grpc.enable_http_proxy',
                                             0), )) as channel:
            stub = gRPC_comm_manager_pb2_grpc.gRPCComServeFuncStub(channel)
            request = message.transform(to_list=True)
            try:
                stub.sendMessage(request)
            except grpc.RpcError as error:
                # Public base class of the RPC errors raised by the call.
                logger.warning(error)

    def send(self, message):
        """
        Send ``message`` to ``message.receiver`` (a single id or a list of \
        ids), or to every neighbor when the receiver is ``None``. Receiver \
        ids that are not known neighbors are skipped.
        """
        receiver = message.receiver
        if receiver is None:
            for receiver_address in self.neighbors.values():
                self._send(receiver_address, message)
            return

        if not isinstance(receiver, list):
            receiver = [receiver]
        for each_receiver in receiver:
            if each_receiver in self.neighbors:
                self._send(self.neighbors[each_receiver], message)

    def receive(self):
        """
        Fetch the next received message from the server functions and parse \
        its ``msg`` field into a ``Message``.

        Returns:
            The parsed ``Message``.
        """
        received_msg = self.server_funcs.receive()
        message = Message()
        message.parse(received_msg.msg)
        return message