65 lines
1.6 KiB
Python
65 lines
1.6 KiB
Python
import os
|
|
import torch
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
|
class Aggregator(ABC):
    """Abstract base class that defines the aggregator interface.

    Concrete aggregators must override :meth:`aggregate`.
    """

    def __init__(self):
        # The base class keeps no state of its own.
        return None

    @abstractmethod
    def aggregate(self, agg_info):
        """Aggregate the given information.

        Args:
            agg_info: information to be aggregated.
        """
        return None
|
|
|
|
|
|
class NoCommunicationAggregator(Aggregator):
    """Aggregator for the setting where clients do not communicate.

    Each client works locally, so :meth:`aggregate` performs no aggregation
    and simply returns an empty dict. The class still offers helpers to
    update, save and restore the wrapped model.
    """

    def __init__(self, model=None, device='cpu', config=None):
        """Store the model, device and configuration.

        Args:
            model: object exposing ``state_dict`` / ``load_state_dict``
                (presumably a ``torch.nn.Module``); may be ``None``.
            device: value forwarded as ``map_location`` to ``torch.load``
                when a checkpoint is restored.
            config: configuration object, stored as ``self.cfg``.
        """
        # The original `super(Aggregator, self).__init__()` began the MRO
        # lookup *after* Aggregator and therefore skipped
        # Aggregator.__init__; zero-argument super() calls the direct parent.
        super().__init__()
        self.model = model
        self.device = device
        self.cfg = config

    def update(self, model_parameters):
        """Load the given parameters into the wrapped model.

        Arguments:
            model_parameters (dict): PyTorch Module object's state_dict.
        """
        # strict=False: keys missing from / unexpected in `model_parameters`
        # are tolerated rather than raising.
        self.model.load_state_dict(model_parameters, strict=False)

    def save_model(self, path, cur_round=-1):
        """Write a checkpoint holding the round index and the model state.

        Args:
            path: destination passed to ``torch.save``.
            cur_round: round index stored under the ``'cur_round'`` key.
        """
        assert self.model is not None

        ckpt = {'cur_round': cur_round, 'model': self.model.state_dict()}
        torch.save(ckpt, path)

    def load_model(self, path):
        """Restore the model from a checkpoint written by :meth:`save_model`.

        Args:
            path: location of the checkpoint file.

        Returns:
            The value stored under ``'cur_round'`` in the checkpoint.

        Raises:
            ValueError: if ``path`` does not exist.
        """
        assert self.model is not None

        if not os.path.exists(path):
            raise ValueError("The file {} does NOT exist".format(path))

        # NOTE(review): torch.load unpickles the file content; only load
        # checkpoints that come from a trusted source.
        ckpt = torch.load(path, map_location=self.device)
        self.model.load_state_dict(ckpt['model'])
        return ckpt['cur_round']

    def aggregate(self, agg_info):
        """Aggregation function; a no-op for this aggregator.

        Args:
            agg_info: information to be aggregated (ignored).

        Returns:
            An empty dict.
        """
        # Clients work locally, so there is nothing to aggregate.
        return {}
|