import json
import pickle
import base64

import numpy as np

from federatedscope.core.proto import gRPC_comm_manager_pb2


def b64serializer(x):
    """Pickle ``x`` and return the pickled payload as base64 ``bytes``."""
    return base64.b64encode(pickle.dumps(x))


class Message(object):
    """
    The data exchanged during an FL course are abstracted as 'Message' in
    FederatedScope.

    A message object includes:
        msg_type: The type of message, which is used to trigger the
            corresponding handlers of server/client
        sender: The sender's ID
        receiver: The receiver's ID (``count_bytes`` also accepts a list of
            receiver IDs)
        state: The training round of the message, which is determined by
            the sender and used to filter out the outdated messages.
        content: The payload of the message
        timestamp: An int or float; the first key used when ordering
            messages
        strategy: Carried along by the JSON (de)serialization only; it is
            not part of the gRPC request built by ``transform``
        serial_num: The last tie-breaker used when ordering messages
    """

    # Fields written by ``msg_to_json`` / read by ``json_to_msg``, in order.
    _JSON_FIELDS = ('msg_type', 'sender', 'receiver', 'state', 'content',
                    'timestamp', 'strategy')
    # Fields packed into the gRPC request by ``transform``, in order.
    _GRPC_FIELDS = ('sender', 'receiver', 'state', 'msg_type', 'content',
                    'timestamp')
    # Fields restored by ``parse``, in order.
    _PARSE_FIELDS = ('sender', 'receiver', 'msg_type', 'state', 'content',
                     'timestamp')

    def __init__(self,
                 msg_type=None,
                 sender=0,
                 receiver=0,
                 state=0,
                 content='None',
                 timestamp=0,
                 strategy=None,
                 serial_num=0):
        self._msg_type = msg_type
        self._sender = sender
        self._receiver = receiver
        self._state = state
        self._content = content
        self._timestamp = timestamp
        self._strategy = strategy
        self.serial_num = serial_num
        # Serializer applied to array-like leaves of 'model_para' messages.
        self.param_serializer = b64serializer

    @property
    def msg_type(self):
        return self._msg_type

    @msg_type.setter
    def msg_type(self, value):
        self._msg_type = value

    @property
    def sender(self):
        return self._sender

    @sender.setter
    def sender(self, value):
        self._sender = value

    @property
    def receiver(self):
        return self._receiver

    @receiver.setter
    def receiver(self, value):
        self._receiver = value

    @property
    def state(self):
        return self._state

    @state.setter
    def state(self, value):
        self._state = value

    @property
    def content(self):
        return self._content

    @content.setter
    def content(self, value):
        self._content = value

    @property
    def timestamp(self):
        return self._timestamp

    @timestamp.setter
    def timestamp(self, value):
        # Kept as an assert so callers keep seeing AssertionError.
        assert isinstance(value, (int, float)), \
            "We only support an int or a float value for timestamp"
        self._timestamp = value

    @property
    def strategy(self):
        return self._strategy

    @strategy.setter
    def strategy(self, value):
        self._strategy = value

    def __lt__(self, other):
        """
        Order messages by timestamp, then by state, then by serial_num.
        """
        if self.timestamp != other.timestamp:
            return self.timestamp < other.timestamp
        if self.state != other.state:
            return self.state < other.state
        return self.serial_num < other.serial_num

    def transform_to_list(self, x):
        """
        Recursively convert ``x`` into plain containers.

        Lists and tuples become lists, dicts become new dicts with the same
        keys, and any leaf exposing ``tolist`` is converted: through
        ``self.param_serializer`` when the message type is 'model_para',
        through ``tolist()`` otherwise. Other leaves are returned unchanged.

        The input containers are never modified; a fresh structure is
        returned so that the caller's own objects stay intact.

        :param x: the (possibly nested) object to convert
        :return: the converted object
        """
        if isinstance(x, (list, tuple)):
            return [self.transform_to_list(item) for item in x]
        if isinstance(x, dict):
            # Build a new dict instead of overwriting the caller's values.
            return {
                key: self.transform_to_list(item)
                for key, item in x.items()
            }
        if hasattr(x, 'tolist'):
            if self.msg_type == 'model_para':
                return self.param_serializer(x)
            return x.tolist()
        return x

    def msg_to_json(self, to_list=False):
        """
        Serialize the message into a JSON string.

        :param to_list: if True, ``self.content`` is first replaced by
            ``transform_to_list(self.content)``
        :return: the JSON string
        """
        # NOTE(review): for 'model_para' messages the serializer yields
        # bytes, which json.dumps cannot encode - confirm that this path is
        # never used with to_list=True for that message type.
        if to_list:
            self.content = self.transform_to_list(self.content)

        json_msg = {name: getattr(self, name) for name in self._JSON_FIELDS}
        return json.dumps(json_msg)

    def json_to_msg(self, json_string):
        """
        Fill this message from a JSON string produced by ``msg_to_json``.

        :param json_string: the JSON document to load
        """
        json_msg = json.loads(json_string)
        for name in self._JSON_FIELDS:
            setattr(self, name, json_msg[name])

    def create_by_type(self, value, nested=False):
        """
        Convert a Python value into the matching protobuf message.

        Dicts become ``mDict_keyIsString`` or ``mDict_keyIsInt`` (chosen
        from the type of the first key; an empty dict uses the string-key
        variant), lists and tuples become ``mList``, and scalars become
        ``mSingle``. Elements of containers are converted recursively.

        :param value: the value to convert
        :param nested: if True, wrap the result in a ``MsgValue``
        :return: the protobuf message (wrapped when ``nested`` is True)
        :raises ValueError: if a scalar has an unsupported type
        """
        if isinstance(value, dict):
            # An empty dict has no key to inspect, so fall back to the
            # string-key container rather than failing on the lookup.
            if not value or isinstance(next(iter(value)), str):
                payload = gRPC_comm_manager_pb2.mDict_keyIsString()
                wrapper_field = 'dict_msg_stringkey'
            else:
                payload = gRPC_comm_manager_pb2.mDict_keyIsInt()
                wrapper_field = 'dict_msg_intkey'
            for key, item in value.items():
                payload.dict_value[key].MergeFrom(
                    self.create_by_type(item, nested=True))
        elif isinstance(value, (list, tuple)):
            payload = gRPC_comm_manager_pb2.mList()
            wrapper_field = 'list_msg'
            for item in value:
                payload.list_value.append(
                    self.create_by_type(item, nested=True))
        else:
            payload = gRPC_comm_manager_pb2.mSingle()
            wrapper_field = 'single_msg'
            # Exact type match on purpose: subclasses such as bool stay
            # unsupported, exactly as before.
            value_type = type(value)
            if value_type in (int, np.int32):
                payload.int_value = value
            elif value_type in (str, bytes):
                payload.str_value = value
            elif value_type in (float, np.float32):
                payload.float_value = value
            else:
                raise ValueError(
                    'The data type {} has not been supported.'.format(
                        type(value)))

        if not nested:
            return payload

        msg_value = gRPC_comm_manager_pb2.MsgValue()
        getattr(msg_value, wrapper_field).MergeFrom(payload)
        return msg_value

    def build_msg_value(self, value):
        """
        Wrap ``value`` into a ``MsgValue`` protobuf message.

        :param value: a scalar, list/tuple or dict
        :return: the ``MsgValue`` holding the converted value
        """
        # The nested form of create_by_type already is the MsgValue wrapper.
        return self.create_by_type(value, nested=True)

    def transform(self, to_list=False):
        """
        Build the gRPC ``MessageRequest`` for this message.

        :param to_list: if True, ``self.content`` is first replaced by
            ``transform_to_list(self.content)``
        :return: the populated ``MessageRequest``
        """
        if to_list:
            self.content = self.transform_to_list(self.content)

        request = gRPC_comm_manager_pb2.MessageRequest()  # map/dict
        for name in self._GRPC_FIELDS:
            request.msg[name].MergeFrom(
                self.build_msg_value(getattr(self, name)))
        return request

    def _parse_msg(self, value):
        """
        Recursively convert a protobuf message back into Python objects.

        :param value: a protobuf message or an already plain value
        :return: the corresponding list, dict or scalar
        """
        if isinstance(value, (gRPC_comm_manager_pb2.MsgValue,
                              gRPC_comm_manager_pb2.mSingle)):
            # Unwrap whichever member of the "type" oneof is set.
            return self._parse_msg(getattr(value, value.WhichOneof("type")))
        if isinstance(value, gRPC_comm_manager_pb2.mList):
            return [self._parse_msg(item) for item in value.list_value]
        if isinstance(value, (gRPC_comm_manager_pb2.mDict_keyIsString,
                              gRPC_comm_manager_pb2.mDict_keyIsInt)):
            return {
                k: self._parse_msg(value.dict_value[k])
                for k in value.dict_value
            }
        return value

    def parse(self, received_msg):
        """
        Fill this message from a received mapping of protobuf values.

        :param received_msg: mapping indexed by 'sender', 'receiver',
            'msg_type', 'state', 'content' and 'timestamp'
        """
        for name in self._PARSE_FIELDS:
            setattr(self, name, self._parse_msg(received_msg[name]))

    def count_bytes(self):
        """
        calculate the message bytes to be sent/received
        :return: tuple of bytes of the message to be sent and received
        """
        from pympler import asizeof
        download_bytes = asizeof.asizeof(self.content)
        # One upload per receiver when the message has a list of receivers.
        if isinstance(self.receiver, list):
            upload_cnt = len(self.receiver)
        else:
            upload_cnt = 1
        upload_bytes = download_bytes * upload_cnt
        return download_bytes, upload_bytes