# FS-TFP/federatedscope/autotune/fedex/utils.py
# (Python source; 88 lines, 2.7 KiB in the original listing)

from torch import nn
from torch.nn.utils import spectral_norm
from federatedscope.autotune.utils import arm2dict
class EncNet(nn.Module):
    """Small bias-free MLP that maps a client encoding to a hidden feature.

    The network is ``Linear -> ReLU -> Linear -> ReLU``; both linear layers
    are wrapped with spectral normalisation and have no bias term.

    Args:
        in_channel: Size of the last dimension of the input encoding.
        out_channel: Size of the last dimension of the produced feature.
        hid_dim: Width of the intermediate layer (64 by default).
    """

    def __init__(self, in_channel, out_channel, hid_dim=64):
        super().__init__()
        # Layers are created in the same order as they are stacked so that
        # parameter initialisation consumes the RNG in a fixed sequence.
        first_proj = spectral_norm(nn.Linear(in_channel, hid_dim, bias=False))
        second_proj = spectral_norm(
            nn.Linear(hid_dim, out_channel, bias=False))
        # The attribute name and the positions inside the Sequential define
        # the state_dict keys, so they must stay as they are.
        self.fc_layer = nn.Sequential(
            first_proj,
            nn.ReLU(inplace=True),
            second_proj,
            nn.ReLU(inplace=True),
        )

    def forward(self, client_enc):
        """Run ``client_enc`` through the stacked layers and return the result."""
        return self.fc_layer(client_enc)
class HyperNet(nn.Module):
    """Policy network producing one categorical distribution per head.

    An ``EncNet`` first maps the input encoding to a 32-dimensional feature;
    every entry of ``sizes`` then gets its own ``Linear -> Softmax`` head on
    top of that shared feature.

    Args:
        input_dim: Size of the last dimension of the input encoding.
        sizes: Iterable of ints; one output head with ``num_cate`` categories
            is created for each entry.
        n_clients: Not used by this class; kept so existing callers that pass
            it keep working.
        device: Not used by this class; kept so existing callers that pass
            it keep working.
    """

    def __init__(
        self,
        input_dim,
        sizes,
        n_clients,
        device,
    ):
        super(HyperNet, self).__init__()
        # The attribute names ``EncNet`` and ``out`` are part of the
        # state_dict keys of saved checkpoints; do not rename them.
        self.EncNet = EncNet(input_dim, 32)
        self.out = nn.ModuleList()
        for num_cate in sizes:
            self.out.append(
                nn.Sequential(
                    nn.Linear(32, num_cate, bias=True),
                    # Normalise over the category axis explicitly. Leaving
                    # ``dim`` unset is deprecated, warns on every call, and
                    # the legacy fallback picks dim 0 for 3-D inputs.
                    nn.Softmax(dim=-1),
                ))

    def forward(self, encoding):
        """Return a list with one probability tensor per output head.

        Args:
            encoding: Tensor whose last dimension equals ``input_dim``.

        Returns:
            A list with ``len(sizes)`` tensors; each is normalised to sum to
            one along its last dimension.
        """
        client_enc = self.EncNet(encoding)
        return [head(client_enc) for head in self.out]
if __name__ == "__main__":
    # Command-line helper: load a trained policy network and write, for every
    # client, the configuration of the arm with the highest probability.
    import argparse

    import torch
    import yaml

    parser = argparse.ArgumentParser(description='Interpret learned policy')
    # Every path is used unconditionally below, so ask for them up front
    # instead of failing later on an attempt to open the empty string.
    parser.add_argument('--ss_path',
                        type=str,
                        required=True,
                        help="YAML file holding the arms, keyed 'arm<idx>'")
    parser.add_argument('--log_path',
                        type=str,
                        required=True,
                        help="YAML checkpoint file containing a 'stop' entry")
    parser.add_argument('--pt_path',
                        type=str,
                        required=True,
                        help="torch file with 'client_encodings' and "
                        "'policy_net' entries")
    parser.add_argument('--save_path',
                        type=str,
                        required=True,
                        help='where the per-client YAML config is written')
    args = parser.parse_args()

    # NOTE(review): FullLoader can construct more than plain data types and
    # torch.load unpickles its input; only point this script at trusted files.
    with open(args.ss_path, 'r') as ips:
        arms = yaml.load(ips, Loader=yaml.FullLoader)
    print(arms)

    with open(args.log_path, 'r') as ips:
        ckpt = yaml.load(ips, Loader=yaml.FullLoader)
    stop_exploration = ckpt['stop']
    print("stop: {}".format(stop_exploration))

    psn_pi = torch.load(args.pt_path, map_location='cpu')
    client_encodings = psn_pi['client_encodings']

    # A single head with one category per arm; the first dimension of the
    # encodings is passed as the number of clients.
    policy_net = HyperNet(
        client_encodings.shape[-1],
        [len(arms)],
        client_encodings.shape[0],
        'cpu',
    ).to('cpu')
    policy_net.load_state_dict(psn_pi['policy_net'])
    policy_net.eval()

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        prbs = policy_net(client_encodings)[0].numpy()

    # Client names are 1-based, rows of ``prbs`` are 0-based.
    clientwise_configs = dict()
    for i, client_prbs in enumerate(prbs):
        arm_idx = client_prbs.argmax()
        clientwise_configs['client_{}'.format(i + 1)] = arm2dict(
            arms['arm{}'.format(arm_idx)])

    class _NoAliasDumper(yaml.Dumper):
        """Dumper that writes repeated objects in full, without aliases."""

        def ignore_aliases(self, data):
            return True

    # Use a local Dumper subclass rather than patching ``yaml.Dumper`` for
    # the whole process.
    with open(args.save_path, 'w') as ops:
        yaml.dump(clientwise_configs, ops, Dumper=_NoAliasDumper)