FS-TFP/federatedscope/core/message.py

270 lines
9.1 KiB
Python

import json
import pickle
import base64
import numpy as np
from federatedscope.core.proto import gRPC_comm_manager_pb2
def b64serializer(x):
    """Pickle ``x`` and return the pickled payload encoded as base64 bytes."""
    pickled = pickle.dumps(x)
    return base64.b64encode(pickled)
class Message(object):
    """
    The data exchanged during an FL course are abstracted as 'Message' in
    FederatedScope.

    A message object includes:
        msg_type: The type of message, which is used to trigger the
            corresponding handlers of server/client
        sender: The sender's ID
        receiver: The receiver's ID (``count_bytes`` also accepts a list
            of IDs)
        state: The training round of the message, which is determined by
            the sender and used to filter out the outdated messages.
        content: The payload carried by the message
        timestamp: An int or a float; the first key used when ordering
            messages (see ``__lt__``)
        strategy: redundant attribute
        serial_num: The last tie-breaker used when ordering messages
    """
    def __init__(self,
                 msg_type=None,
                 sender=0,
                 receiver=0,
                 state=0,
                 content='None',
                 timestamp=0,
                 strategy=None,
                 serial_num=0):
        self._msg_type = msg_type
        self._sender = sender
        self._receiver = receiver
        self._state = state
        self._content = content
        self._timestamp = timestamp
        self._strategy = strategy
        self.serial_num = serial_num
        # Serializer applied to array-like values of 'model_para' messages
        # in ``transform_to_list``.
        self.param_serializer = b64serializer

    @property
    def msg_type(self):
        """The type of the message."""
        return self._msg_type

    @msg_type.setter
    def msg_type(self, value):
        self._msg_type = value

    @property
    def sender(self):
        """The sender's ID."""
        return self._sender

    @sender.setter
    def sender(self, value):
        self._sender = value

    @property
    def receiver(self):
        """The receiver's ID."""
        return self._receiver

    @receiver.setter
    def receiver(self, value):
        self._receiver = value

    @property
    def state(self):
        """The training round of the message."""
        return self._state

    @state.setter
    def state(self, value):
        self._state = value

    @property
    def content(self):
        """The payload carried by the message."""
        return self._content

    @content.setter
    def content(self, value):
        self._content = value

    @property
    def timestamp(self):
        """The timestamp of the message (an int or a float)."""
        return self._timestamp

    @timestamp.setter
    def timestamp(self, value):
        # Kept as an assert so that callers observe the same AssertionError
        # as before on an invalid value.
        assert isinstance(value, (int, float)), \
            "We only support an int or a float value for timestamp"
        self._timestamp = value

    @property
    def strategy(self):
        """Redundant attribute, carried along with the message."""
        return self._strategy

    @strategy.setter
    def strategy(self, value):
        self._strategy = value

    def __lt__(self, other):
        """
        Order messages by timestamp, then by state, then by serial_num.
        """
        if self.timestamp != other.timestamp:
            return self.timestamp < other.timestamp
        if self.state != other.state:
            return self.state < other.state
        return self.serial_num < other.serial_num

    def transform_to_list(self, x):
        """
        Recursively convert ``x`` into plain containers.

        Lists and tuples become lists, dicts become new dicts with converted
        values, and any object exposing ``tolist`` is converted with
        ``tolist()`` -- or with ``self.param_serializer`` when the message
        type is 'model_para'. Everything else is returned unchanged.

        The input is never modified: a dict is rebuilt rather than updated
        in place, so a payload the caller still holds stays intact.

        :param x: the object to convert
        :return: the converted object
        """
        if isinstance(x, (list, tuple)):
            return [self.transform_to_list(each_x) for each_x in x]
        if isinstance(x, dict):
            return {
                key: self.transform_to_list(item)
                for key, item in x.items()
            }
        if hasattr(x, 'tolist'):
            if self.msg_type == 'model_para':
                return self.param_serializer(x)
            return x.tolist()
        return x

    def msg_to_json(self, to_list=False):
        """
        Dump the message into a JSON string.

        ``serial_num`` is not part of the dumped fields.

        :param to_list: when True, ``self.content`` is first replaced by
            ``transform_to_list(self.content)``
        :return: the JSON string
        """
        if to_list:
            self.content = self.transform_to_list(self.content)

        json_msg = {
            'msg_type': self.msg_type,
            'sender': self.sender,
            'receiver': self.receiver,
            'state': self.state,
            'content': self.content,
            'timestamp': self.timestamp,
            'strategy': self.strategy,
        }
        return json.dumps(json_msg)

    def json_to_msg(self, json_string):
        """
        Fill this message in place from a JSON string as produced by
        ``msg_to_json``.

        :param json_string: the JSON string to load
        """
        json_msg = json.loads(json_string)
        self.msg_type = json_msg['msg_type']
        self.sender = json_msg['sender']
        self.receiver = json_msg['receiver']
        self.state = json_msg['state']
        self.content = json_msg['content']
        self.timestamp = json_msg['timestamp']
        self.strategy = json_msg['strategy']

    @staticmethod
    def _has_string_keys(value):
        """
        Tell whether the dict ``value`` uses the string-keyed dict message.

        Only the first key is inspected. An empty dict has no key to
        inspect and defaults to string keys instead of raising IndexError.
        """
        if not value:
            return True
        return isinstance(next(iter(value)), str)

    def create_by_type(self, value, nested=False):
        """
        Recursively convert ``value`` into its gRPC message.

        Dicts become ``mDict_keyIsString``/``mDict_keyIsInt``, lists and
        tuples become ``mList`` and scalars become ``mSingle``.

        :param value: a dict, list, tuple, int, float, str or bytes; numpy
            integer and floating scalars are accepted as well
        :param nested: when True, the result is wrapped into a ``MsgValue``
        :return: the gRPC message, wrapped into a ``MsgValue`` if ``nested``
        :raises ValueError: if a scalar has an unsupported type
        """
        if isinstance(value, dict):
            string_keys = self._has_string_keys(value)
            if string_keys:
                m_dict = gRPC_comm_manager_pb2.mDict_keyIsString()
            else:
                m_dict = gRPC_comm_manager_pb2.mDict_keyIsInt()
            for key, item in value.items():
                m_dict.dict_value[key].MergeFrom(
                    self.create_by_type(item, nested=True))
            if not nested:
                return m_dict
            msg_value = gRPC_comm_manager_pb2.MsgValue()
            if string_keys:
                msg_value.dict_msg_stringkey.MergeFrom(m_dict)
            else:
                msg_value.dict_msg_intkey.MergeFrom(m_dict)
            return msg_value

        if isinstance(value, (list, tuple)):
            m_list = gRPC_comm_manager_pb2.mList()
            for each in value:
                m_list.list_value.append(
                    self.create_by_type(each, nested=True))
            if not nested:
                return m_list
            msg_value = gRPC_comm_manager_pb2.MsgValue()
            msg_value.list_msg.MergeFrom(m_list)
            return msg_value

        m_single = gRPC_comm_manager_pb2.mSingle()
        # bool is a subclass of int; it was never supported as a scalar and
        # is still rejected so that True/False are not silently sent as 1/0.
        if isinstance(value, bool):
            raise ValueError(
                'The data type {} has not been supported.'.format(
                    type(value)))
        if isinstance(value, (int, np.integer)):
            # Convert numpy scalars (np.int32, np.int64, ...) to a built-in
            # int before handing them to protobuf.
            m_single.int_value = int(value)
        elif isinstance(value, (str, bytes)):
            m_single.str_value = value
        elif isinstance(value, (float, np.floating)):
            m_single.float_value = float(value)
        else:
            raise ValueError(
                'The data type {} has not been supported.'.format(
                    type(value)))
        if not nested:
            return m_single
        msg_value = gRPC_comm_manager_pb2.MsgValue()
        msg_value.single_msg.MergeFrom(m_single)
        return msg_value

    def build_msg_value(self, value):
        """
        Wrap ``value`` into a ``MsgValue``, selecting the list, dict
        (string or int keys) or single slot from the type of ``value``.

        :param value: the attribute value to wrap
        :return: the filled ``MsgValue``
        """
        msg_value = gRPC_comm_manager_pb2.MsgValue()
        converted = self.create_by_type(value)
        if isinstance(value, (list, tuple)):
            msg_value.list_msg.MergeFrom(converted)
        elif isinstance(value, dict):
            if self._has_string_keys(value):
                msg_value.dict_msg_stringkey.MergeFrom(converted)
            else:
                msg_value.dict_msg_intkey.MergeFrom(converted)
        else:
            msg_value.single_msg.MergeFrom(converted)
        return msg_value

    def transform(self, to_list=False):
        """
        Build the gRPC ``MessageRequest`` of this message.

        The request carries 'sender', 'receiver', 'state', 'msg_type',
        'content' and 'timestamp'; ``strategy`` and ``serial_num`` are not
        included.

        :param to_list: when True, ``self.content`` is first replaced by
            ``transform_to_list(self.content)``
        :return: the filled ``MessageRequest``
        """
        if to_list:
            self.content = self.transform_to_list(self.content)

        splited_msg = gRPC_comm_manager_pb2.MessageRequest()  # map/dict
        splited_msg.msg['sender'].MergeFrom(
            self.build_msg_value(self.sender))
        splited_msg.msg['receiver'].MergeFrom(
            self.build_msg_value(self.receiver))
        splited_msg.msg['state'].MergeFrom(self.build_msg_value(self.state))
        splited_msg.msg['msg_type'].MergeFrom(
            self.build_msg_value(self.msg_type))
        splited_msg.msg['content'].MergeFrom(
            self.build_msg_value(self.content))
        splited_msg.msg['timestamp'].MergeFrom(
            self.build_msg_value(self.timestamp))
        return splited_msg

    def _parse_msg(self, value):
        """
        Recursively convert a gRPC message back into Python objects.

        ``MsgValue`` and ``mSingle`` are unwrapped through their "type"
        oneof, ``mList`` becomes a list, both dict messages become a dict,
        and any other value is returned as is.
        """
        if isinstance(value, (gRPC_comm_manager_pb2.MsgValue,
                              gRPC_comm_manager_pb2.mSingle)):
            return self._parse_msg(getattr(value, value.WhichOneof("type")))
        if isinstance(value, gRPC_comm_manager_pb2.mList):
            return [self._parse_msg(each) for each in value.list_value]
        if isinstance(value, (gRPC_comm_manager_pb2.mDict_keyIsString,
                              gRPC_comm_manager_pb2.mDict_keyIsInt)):
            return {
                k: self._parse_msg(value.dict_value[k])
                for k in value.dict_value
            }
        return value

    def parse(self, received_msg):
        """
        Fill this message in place from a received mapping that holds the
        'sender', 'receiver', 'msg_type', 'state', 'content' and
        'timestamp' entries.

        :param received_msg: the mapping of received gRPC values
        """
        self.sender = self._parse_msg(received_msg['sender'])
        self.receiver = self._parse_msg(received_msg['receiver'])
        self.msg_type = self._parse_msg(received_msg['msg_type'])
        self.state = self._parse_msg(received_msg['state'])
        self.content = self._parse_msg(received_msg['content'])
        self.timestamp = self._parse_msg(received_msg['timestamp'])

    def count_bytes(self):
        """
        calculate the message bytes to be sent/received

        The download size is the size of ``self.content`` as reported by
        ``pympler.asizeof``; the upload size is the download size times the
        number of receivers (1 unless ``self.receiver`` is a list).

        :return: tuple ``(download_bytes, upload_bytes)``
        """
        from pympler import asizeof
        download_bytes = asizeof.asizeof(self.content)
        if isinstance(self.receiver, list):
            upload_cnt = len(self.receiver)
        else:
            upload_cnt = 1
        upload_bytes = download_bytes * upload_cnt
        return download_bytes, upload_bytes