FS-TFP/benchmark/pFL-Bench/res_analysis_plot/render_paper_res.py

833 lines
29 KiB
Python

import copy
import json
import os
import numpy as np
import wandb
from collections import OrderedDict
import yaml
# wandb public-API client; load_best_repeat_res queries runs through it.
# NOTE(review): created at import time, so importing this module presumably
# needs working wandb credentials/network -- confirm.
api = wandb.Api()
# Presumably the wandb "entity/project" of the pFL-Bench runs; it is not
# referenced in the code shown here (load_best_repeat_res hard-codes the
# "pfl-bench-best-repeat" project instead).
name_project = "daoyuan/pFL-bench"
def _and_filter(*conditions):
    """Return a wandb run filter requiring every ``(field, value)`` pair in *conditions*."""
    return {"$and": [{field: value} for field, value in conditions]}


# Every table below maps a row label (dataset name) to the wandb filter that
# selects that dataset's runs: {dataset_name: filter}.
#
# Main paper table. Rows that used to be disabled here: "all" (no filter),
# "FEMNIST-all" (femnist without a sample-rate constraint) and
# "cifar10-alpha05".
filters_each_line_main_table = OrderedDict([
    ("FEMNIST-s02",
     _and_filter(("config.data.type", "femnist"),
                 ("config.federate.sample_client_rate", 0.2),
                 ("state", "finished"))),
    ("sst2", _and_filter(("config.data.type", "sst2@huggingface_datasets"))),
    ("pubmed", _and_filter(("config.data.type", "pubmed"))),
])

# CIFAR-10 with the three splitter alphas.
filters_each_line_all_cifar10 = OrderedDict(
    (f"cifar10-alpha{label}",
     _and_filter(("config.data.type", "CIFAR10@torchvision"),
                 ("config.data.splitter_args", [{"alpha": alpha}])))
    for label, alpha in (("5", 5), ("05", 0.5), ("01", 0.1)))

# FEMNIST with the three client sample rates (finished runs only).
filters_each_line_femnist_all_s = OrderedDict(
    (f"FEMNIST-s{label}",
     _and_filter(("config.data.type", "femnist"),
                 ("config.federate.sample_client_rate", rate),
                 ("state", "finished")))
    for label, rate in (("02", 0.2), ("01", 0.1), ("005", 0.05)))

# Graph datasets: the row label is the data type itself.
filters_each_line_all_graph = OrderedDict(
    (name, _and_filter(("config.data.type", name)))
    for name in ("pubmed", "cora", "citeseer"))

# NLP datasets loaded through huggingface_datasets.
filters_each_line_all_nlp = OrderedDict(
    (name, _and_filter(("config.data.type", f"{name}@huggingface_datasets")))
    for name in ("cola", "sst2"))
# Lookup from sweep name to wandb sweep id; left empty in this script.
sweep_name_2_id = {}

# wandb summary keys shown in the generalization table: accuracy on
# participating clients, accuracy on unseen clients, and the gap
# (unseen minus participating).
column_names_generalization = [
    "best_client_summarized_weighted_avg/test_acc",
    "best_unseen_client_summarized_weighted_avg_unseen/test_acc",
    "participation_gap",
]
# wandb summary keys shown in the fairness table.
column_names_fair = [
    "best_client_summarized_avg/test_acc",
    "best_client_summarized_fairness/test_acc_std",
    "best_client_summarized_fairness/test_acc_bottom_decile",
]
# System-efficiency keys ("sys_avg/local_convergence_round" is not tracked).
column_names_efficiency = [
    "sys_avg/total_flops",
    "sys_avg/total_upload_bytes",
    "sys_avg/total_download_bytes",
    "sys_avg/global_convergence_round",
]
# y-axis labels, aligned index-by-index with column_names_generalization.
column_names_generalization_for_plot = [
    "Acc (Parti.)", "Acc (Un-parti.)", "Generalization Gap",
]
# Axis labels of the trade-off plots, keyed by DataFrame column name.
column_name_for_plot = {
    "best_client_summarized_weighted_avg/test_acc": "Acc (Parti.)",
    "total_flops": "Total Flops",
    "communication_bytes": "Communication Bytes",
    "sys_avg/global_convergence_round": "Convergence Round",
}

# (internal method key, display name) in table/plot row order. The display
# name is the method part of every expname_tag.
sorted_method_name_pair = [
    ("global-train", "Global-Train"),
    ("isolated-train", "Isolated"),
    ("fedavg", "FedAvg"),
    ("fedavg-ft", "FedAvg-FT"),
    ("fedopt", "FedOpt"),
    ("fedopt-ft", "FedOpt-FT"),
    ("pfedme", "pFedMe"),
    ("ft-pfedme", "pFedMe-FT"),
    ("fedbn", "FedBN"),
    ("fedbn-ft", "FedBN-FT"),
    ("fedbn-fedopt", "FedBN-FedOPT"),
    ("fedbn-fedopt-ft", "FedBN-FedOPT-FT"),
    ("ditto", "Ditto"),
    ("ditto-ft", "Ditto-FT"),
    ("ditto-fedbn", "Ditto-FedBN"),
    ("ditto-fedbn-ft", "Ditto-FedBN-FT"),
    ("ditto-fedbn-fedopt", "Ditto-FedBN-FedOpt"),
    ("ditto-fedbn-fedopt-ft", "Ditto-FedBN-FedOpt-FT"),
    ("fedem", "FedEM"),
    ("fedem-ft", "FedEM-FT"),
    ("fedbn-fedem", "FedEM-FedBN"),
    ("fedbn-fedem-ft", "FedEM-FedBN-FT"),
    ("fedbn-fedem-fedopt", "FedEM-FedBN-FedOPT"),
    ("fedbn-fedem-fedopt-ft", "FedEM-FedBN-FedOPT-FT"),
]
sorted_keys = OrderedDict(sorted_method_name_pair)
expected_keys = set(sorted_keys)
expected_method_names = list(sorted_keys.values())
expected_datasets_name = [
    "cola", "sst2", "pubmed", "cora", "citeseer", "cifar10-alpha5",
    "cifar10-alpha05", "cifar10-alpha01", "FEMNIST-s02", "FEMNIST-s01",
    "FEMNIST-s005",
]
expected_seed_set = ["1", "2", "3"]
# Methods without combined variants; tables rank them a second time among
# themselves (red/blue highlighting).
original_method_names = [
    "Global-Train", "Isolated", "FedAvg", "pFedMe", "FedBN", "Ditto", "FedEM",
]

# One "{method}_{dataset}_seed{n}" tag per seed, plus one
# "{method}_{dataset}_repeat" tag that aggregates the seeds.
_tag_suffixes = [f"seed{seed}" for seed in expected_seed_set] + ["repeat"]
expected_expname_tag = {
    f"{method_name}_{dataset_name}_{suffix}"
    for method_name in expected_method_names
    for dataset_name in expected_datasets_name
    for suffix in _tag_suffixes
}
from collections import defaultdict
# Launch commands for runs that could not be found, keyed by seed
# (filled by load_best_repeat_res).
all_missing_scripts = defaultdict(list)
# all_res_structed[expname_tag][metric]: per-seed tags start at the "-"
# missing marker; *_repeat tags start with an empty list that later
# collects the per-seed values (see avg_res_of_seeds).
all_res_structed = defaultdict(dict)
for expname_tag in expected_expname_tag:
    collects_seeds = "repeat" in expname_tag
    for metric in column_names_generalization + column_names_efficiency + column_names_fair:
        all_res_structed[expname_tag][metric] = [] if collects_seeds else "-"
def _copy_run_metrics(run, expname_tag, method, metrics):
    """Copy the summary metrics of one finished wandb run into all_res_structed[expname_tag]."""
    res = all_res_structed[expname_tag]
    for metric in metrics:
        if method in ("Isolated", "Global-Train"):
            # These baselines report "-" for unseen-client metrics and the gap;
            # Global-Train additionally reports "-" for fairness metrics.
            skip_generalize = "unseen" in metric or metric == "participation_gap"
            skip_global_fairness = method == "Global-Train" and "fairness" in metric
            if skip_generalize or skip_global_fairness:
                res[metric] = "-"
                continue
        try:
            if metric == "participation_gap":
                # Relies on both accuracies appearing earlier in `metrics`.
                res[metric] = (
                    res["best_unseen_client_summarized_weighted_avg_unseen/test_acc"]
                    - res["best_client_summarized_weighted_avg/test_acc"])
            else:
                res[metric] = run.summary[metric]
        except (KeyError, TypeError):
            # KeyError: the metric is absent from the run summary.
            # TypeError: an operand of the gap is still a "-" placeholder.
            print(f"Something wrong: {metric} unavailable for {expname_tag}")


def _rerun_command(method, dataname, seed_tag, expname_tag):
    """Build the shell command that re-launches one missing best-config run."""
    yaml_name = f"{method}_{dataname}.yaml"
    if "Global" in method:
        yaml_name = f"'{yaml_name}'"
        expname_tag = expname_tag.replace("Global Train", "Global-Train")
    seed_num = seed_tag.replace("seed", "")
    return ("python federatedscope/main.py --cfg "
            "scripts/personalization_exp_scripts/pfl_bench/yaml_best_runs/"
            f"{yaml_name} seed {seed_num} expname_tag {expname_tag} "
            "wandb.name_project pfl-bench-best-repeat")


def load_best_repeat_res(filter_seed_set=None):
    """Load best-config runs from wandb into ``all_res_structed``.

    For every tag in ``expected_expname_tag`` this queries the
    ``pfl-bench-best-repeat`` wandb project. It copies the summary metrics of
    each finished run and reports runs that are not finished. For tags without
    any finished run, it stores a re-launch command in ``all_missing_scripts``.
    All collected commands are printed at the end, grouped by seed tag.

    Args:
        filter_seed_set: seeds (e.g. ``["1", "repeat"]``) whose missing or
            duplicated runs should not be reported. Each value ``s`` matches
            the tag's seed component ``s`` or ``seed{s}``. ``None`` (the
            default) reports every tag.
    """
    skip_seed_tags = set()
    for skip_seed in filter_seed_set or []:
        # Accept both the bare value ("1") and the tag form ("seed1").
        skip_seed_tags.update((skip_seed, f"seed{skip_seed}"))
    metrics = column_names_generalization + column_names_efficiency + column_names_fair
    for expname_tag in expected_expname_tag:
        run_filter = {"$and": [{"config.expname_tag": expname_tag}]}
        filtered_runs = api.runs("pfl-bench-best-repeat", filters=run_filter)
        method, dataname, seed_tag = expname_tag.split("_")
        finished_run_cnt = 0
        for run in filtered_runs:
            if run.state != "finished":
                print(f"run {run} is not finished")
                continue
            finished_run_cnt += 1
            _copy_run_metrics(run, expname_tag, method, metrics)
        # Compare against the parsed seed component. A substring test would
        # also match dataset names ("1" occurs in "cifar10" and "FEMNIST-s01").
        print_missing = seed_tag not in skip_seed_tags
        if finished_run_cnt == 0 and print_missing:
            print(f"Missing run {expname_tag}")
            all_missing_scripts[seed_tag].append(
                _rerun_command(method, dataname, seed_tag, expname_tag))
        elif finished_run_cnt != 1 and print_missing:
            print(f"run_cnt = {finished_run_cnt} for the exp {expname_tag}")
    for seed_tag, scripts in all_missing_scripts.items():
        print(
            f"+================= All MISSING SCRIPTS, seed={seed_tag} =====================+, cnt={len(scripts)}"
        )
        for script in scripts:
            print(script)
        print()
def bytes_to_unit_size(size_bytes):
    """Render a byte/flop count with a binary (1024-based) unit suffix.

    Examples: ``2048 -> "2.0K"``, ``1500 -> "1.46K"``, ``0 -> "0"``. The
    mantissa is rounded to two decimals. ``unit_size_to_bytes`` parses the
    result back, up to that rounding.

    Values below 1 (e.g. a fractional average) get no suffix. Negative values
    are scaled by their magnitude. Values beyond the largest unit stay in "Y".
    """
    import math
    if size_bytes == 0:
        return "0"
    size_name = ("", "K", "M", "G", "T", "P", "E", "Z", "Y")
    # Pick the unit by repeated division rather than floor(log(x, 1024)):
    # the log form yields index -1 ("Y") for 0 < x < 1, fails for x < 0 and
    # overruns the tuple past "Y".
    i = 0
    magnitude = abs(size_bytes)
    while magnitude >= 1024 and i < len(size_name) - 1:
        magnitude /= 1024
        i += 1
    s = round(size_bytes / math.pow(1024, i), 2)
    return f"{s}{size_name[i]}"
def unit_size_to_bytes(size_str):
    """Parse a string such as those produced by ``bytes_to_unit_size`` into a number.

    Examples: ``"2.0K" -> 2048.0``, ``"1.5M" -> 1.5 * 1024 ** 2``,
    ``"12" -> 12.0``. Non-string inputs are returned unchanged. So are
    strings that cannot be parsed, such as the "-" missing-value marker or
    an empty string.
    """
    import math
    if not isinstance(size_str, str):
        return size_str
    units = ("K", "M", "G", "T", "P", "E", "Z", "Y")
    try:
        if size_str and size_str[-1] in units:
            # "K" is 1024**1, "M" is 1024**2, ...
            return float(size_str[:-1]) * math.pow(1024, units.index(size_str[-1]) + 1)
        return float(size_str)
    except ValueError:
        # Not a number: keep the original text (e.g. "-" or a method name).
        return size_str
def avg_res_of_seeds():
    """Fold the per-seed results into their ``{method}_{data}_repeat`` entries.

    Pass 1 appends every seed tag's metric value, including "-" placeholders,
    to the list stored under the matching ``_repeat`` tag. It prints a notice
    for each missing value, except for Global-Train and Isolated tags.

    Pass 2 replaces every ``_repeat`` list by the mean of its non-"-"
    entries, each parsed with ``unit_size_to_bytes``. Flops/bytes metrics are
    re-rendered with ``bytes_to_unit_size``. A list with no valid entry
    becomes "-".
    """
    metrics = column_names_generalization + column_names_efficiency + column_names_fair
    seed_tags = [tag for tag in expected_expname_tag if "repeat" not in tag]
    repeat_tags = [tag for tag in expected_expname_tag if "repeat" in tag]

    for tag in seed_tags:
        is_baseline = "Global" in tag or "Isolated" in tag
        method, dataname, _ = tag.split("_")
        target = all_res_structed[f"{method}_{dataname}_repeat"]
        for metric in metrics:
            value = all_res_structed[tag][metric]
            if value == "-" and not is_baseline:
                print(f"missing {tag} for metric {metric}")
            target[metric].append(value)

    for tag in repeat_tags:
        for metric in metrics:
            parsed = [
                unit_size_to_bytes(v)
                for v in all_res_structed[tag][metric]
                if v != "-"
            ]
            if not parsed:
                all_res_structed[tag][metric] = "-"
                continue
            mean = sum(parsed) / len(parsed)
            needs_unit = "flops" in metric or "bytes" in metric
            all_res_structed[tag][metric] = bytes_to_unit_size(mean) if needs_unit else mean
def highlight_tex_res_in_table(res_to_print_matrix_raw,
                               rank_order,
                               need_scale=False,
                               filter_out=None,
                               convergence_case=False):
    r"""Decorate the best/second-best entries of each result column with LaTeX.

    Every row is ``[method_name, cell_1, cell_2, ...]``. Each cell is a number
    rendered as a string, or "-" for a missing result. Result column ``j`` is
    ranked with a dense ranking (ties share a rank) according to
    ``rank_order[j]``: "+" means larger is better, anything else means smaller
    is better.

    Among all rows, the best cell is wrapped in ``\textbf{}`` and the second
    best in ``\underline{}``. Among the rows whose method is in
    ``original_method_names``, the best cell is additionally wrapped in
    ``\color{red}{}`` and the second best in ``\color{blue}{}``.

    Missing cells are left out of the ranking. They are never highlighted and
    cannot outrank real results, however large those results are.

    Args:
        res_to_print_matrix_raw: rows as described above. Kept rows are
            decorated in place.
        rank_order: ranking direction for each result column ("+" or "-").
        need_scale: parse unit-suffixed cells ("1.2K", "3.4G", ...) through
            ``unit_size_to_bytes`` before ranking.
        filter_out: method names whose rows are dropped from the table.
        convergence_case: also treat a value of 0 ("0" or "0.00") as missing.
            NOTE(review): presumably a run that never converged -- confirm.

    Returns:
        The (filtered) list of rows with decorated cells.
    """
    if filter_out is not None:
        res_to_print_matrix = [
            line for line in res_to_print_matrix_raw if line[0] not in filter_out
        ]
    else:
        res_to_print_matrix = res_to_print_matrix_raw
    if not res_to_print_matrix:
        return res_to_print_matrix

    def _to_number(cell):
        # NaN marks a missing result; pandas' rank() leaves NaN unranked.
        # (A finite sentinel such as 9999999999 can be smaller than real
        # flops/bytes values and would then win the ranking.)
        if cell == "-":
            return np.nan
        value = float(unit_size_to_bytes(cell) if need_scale else cell)
        if convergence_case and value == 0:
            return np.nan
        return value

    selected_method_idx = [
        row_i for row_i, line in enumerate(res_to_print_matrix)
        if line[0] in original_method_names
    ]
    n_result_cols = len(res_to_print_matrix[0]) - 1
    for col_i in range(n_result_cols):
        cell_i = col_i + 1
        values = np.array(
            [_to_number(line[cell_i]) for line in res_to_print_matrix],
            dtype=float)
        if rank_order[col_i] == "+":
            # rank() gives rank 1 to the smallest value; negate so larger wins.
            values = -values
        rank_all = pd.Series(values).rank(method="dense").to_numpy()
        rank_selected = dict(
            zip(selected_method_idx,
                pd.Series(values[selected_method_idx]).rank(
                    method="dense").to_numpy()))
        for row_i, line in enumerate(res_to_print_matrix):
            if rank_all[row_i] == 1:
                line[cell_i] = "\\textbf{" + line[cell_i] + "}"
            if rank_all[row_i] == 2:
                line[cell_i] = "\\underline{" + line[cell_i] + "}"
            if rank_selected.get(row_i) == 1:
                line[cell_i] = "\\color{red}{" + line[cell_i] + "}"
            if rank_selected.get(row_i) == 2:
                line[cell_i] = "\\color{blue}{" + line[cell_i] + "}"
    return res_to_print_matrix
def _format_percent(value):
    """Render a fraction as a percentage string with two decimals; "-" passes through."""
    return value if value == "-" else "{:.2f}".format(value * 100)


def _print_highlighted_table(title, dataset_names, rows, colum_order_per_data,
                             **highlight_kwargs):
    """Print one LaTeX table body with the best results highlighted.

    ``colum_order_per_data`` gives the ranking direction ("+" = larger is
    better) of the columns of a single dataset; it is repeated once per
    dataset. Extra keyword arguments are forwarded to
    ``highlight_tex_res_in_table``. Each row is printed with its cells joined
    by "&" and terminated by a LaTeX row break.
    """
    print(f"\n=============res_of_each_line [{title}]==============="
          + ",".join(dataset_names))
    rank_order = colum_order_per_data * len(dataset_names)
    rows = highlight_tex_res_in_table(rows, rank_order=rank_order, **highlight_kwargs)
    for row in rows:
        print("&".join(row) + "\\\\")


def print_paper_table_from_repeat(filters_each_line_table):
    """Print the LaTeX result tables for the datasets of one filter table.

    Reads the seed-averaged ``{method}_{dataset}_repeat`` entries of
    ``all_res_structed`` for every method in ``expected_method_names`` and
    every dataset key of ``filters_each_line_table``. Four highlighted tables
    are printed:

    * generalization: the three ``column_names_generalization`` metrics;
    * fairness: the three ``column_names_fair`` metrics (without Global-Train);
    * flops / communication / accuracy (without Global-Train and Isolated);
    * convergence round / accuracy (without Global-Train and Isolated).

    NOTE(review): the "communication" column uses upload bytes only, while
    load_data_to_pd sums upload and download -- confirm this is intended.
    """
    dataset_names = list(filters_each_line_table.keys())
    n_gen = len(column_names_generalization)
    n_fair = len(column_names_fair)
    n_eff = len(column_names_efficiency)
    flops_col = column_names_efficiency.index("sys_avg/total_flops")
    upload_col = column_names_efficiency.index("sys_avg/total_upload_bytes")
    round_col = column_names_efficiency.index("sys_avg/global_convergence_round")
    gen_acc_col = column_names_generalization.index(
        "best_client_summarized_weighted_avg/test_acc")
    fair_acc_col = column_names_fair.index("best_client_summarized_avg/test_acc")

    generalization_rows, fair_rows, commu_rows, conver_rows = [], [], [], []
    for key in expected_method_names:
        gen_res, fair_res, eff_res = [], [], []
        for dataset_name in dataset_names:
            res_dict = all_res_structed[f"{key}_{dataset_name}_repeat"]
            gen_res.extend(res_dict[metric] for metric in column_names_generalization)
            fair_res.extend(res_dict[metric] for metric in column_names_fair)
            for metric in column_names_efficiency:
                res = res_dict[metric]
                # A "-" (no valid seed) cannot go through the float format.
                if "round" in metric and res != "-":
                    res = "{:.2f}".format(res)
                eff_res.append(res)
        generalization_rows.append([key] + [_format_percent(v) for v in gen_res])
        fair_rows.append([key] + [_format_percent(v) for v in fair_res])
        commu_row, conver_row = [key], [key]
        for i in range(len(dataset_names)):
            commu_row += [
                str(eff_res[i * n_eff + flops_col]),
                str(eff_res[i * n_eff + upload_col]),
                str(_format_percent(gen_res[i * n_gen + gen_acc_col])),
            ]
            conver_row += [
                str(eff_res[i * n_eff + round_col]),
                str(_format_percent(fair_res[i * n_fair + fair_acc_col])),
            ]
        commu_rows.append(commu_row)
        conver_rows.append(conver_row)

    # "+" indicates the larger, the better.
    _print_highlighted_table("Generalization", dataset_names,
                             generalization_rows, ["+", "+", "+"])
    _print_highlighted_table("Fairness", dataset_names, fair_rows,
                             ["+", "-", "+"],
                             filter_out=["Global-Train"])
    _print_highlighted_table("flops, communication, acc", dataset_names,
                             commu_rows, ["-", "-", "+"],
                             need_scale=True,
                             filter_out=["Global-Train", "Isolated"])
    _print_highlighted_table("converge_round, acc", dataset_names,
                             conver_rows, ["-", "+"],
                             filter_out=["Global-Train", "Isolated"],
                             convergence_case=True)
import json
# Pre-populate the per-seed entries of all_res_structed from a JSON cache in
# the working directory. NOTE(review): presumably exported by an earlier run
# of this script -- confirm. load_best_repeat_res may overwrite these values.
with open('best_res_all_metric.json', 'r') as fp:
    all_res_structed_load = json.load(fp)
for expname_tag in (tag for tag in expected_expname_tag if "repeat" not in tag):
    cached_res = all_res_structed_load[expname_tag]
    all_res_structed[expname_tag].update(
        (metric, cached_res[metric])
        for metric in column_names_generalization + column_names_efficiency + column_names_fair)
# add all res to a df
import pandas as pd
def load_data_to_pd(use_repeat_res=False):
    """Flatten ``all_res_structed`` into a pandas DataFrame for plotting.

    Args:
        use_repeat_res: if True, use only the seed-averaged ``*_repeat``
            entries; otherwise use only the per-seed entries.

    Returns:
        A DataFrame with columns ``method``, ``data``, ``seed``, every metric
        of ``column_names_generalization + column_names_fair +
        column_names_efficiency``, plus four derived columns:

        * ``s``: FEMNIST sample rate parsed from the dataset name
          ("FEMNIST-s005" -> 0.05), "-" for other datasets;
        * ``alpha``: CIFAR-10 splitter alpha ("cifar10-alpha05" -> 0.5,
          "cifar10-alpha5" -> 5.0), "-" for other datasets;
        * ``communication_bytes``: upload + download bytes, "-" if either is
          missing;
        * ``total_flops``: total flops without unit suffix.
    """
    metrics = column_names_generalization + column_names_fair + column_names_efficiency
    all_res_for_pd = []
    for expname_tag in expected_expname_tag:
        if ("repeat" in expname_tag) != bool(use_repeat_res):
            continue
        method, dataname, seed = expname_tag.split("_")
        metric_res = all_res_structed[expname_tag]
        row = [method, dataname, seed] + [metric_res[m] for m in metrics]
        s = "-"
        alpha = "-"
        if "FEMNIST-s0" in dataname:
            s = float(dataname.replace("FEMNIST-s0", "0."))
        if "cifar10-alpha0" in dataname:
            alpha = float(dataname.replace("cifar10-alpha0", "0."))
        elif "cifar10-alpha" in dataname:
            alpha = float(dataname.replace("cifar10-alpha", ""))
        upload = unit_size_to_bytes(metric_res["sys_avg/total_upload_bytes"])
        download = unit_size_to_bytes(metric_res["sys_avg/total_download_bytes"])
        if isinstance(upload, (int, float)) and isinstance(download, (int, float)):
            total_com_bytes = upload + download
        else:
            # A missing ("-") side would otherwise raise TypeError or yield "--".
            total_com_bytes = "-"
        total_flops = unit_size_to_bytes(metric_res["sys_avg/total_flops"])
        row += [s, alpha, total_com_bytes, total_flops]
        all_res_for_pd.append(row)
    return pd.DataFrame.from_records(
        all_res_for_pd,
        columns=["method", "data", "seed"] + metrics +
        ["s", "alpha", "communication_bytes", "total_flops"])
def plot_generalization_lines(all_res_pd, data_cate, data_cate_name):
    """Save and show one line plot per generalization metric over a dataset family.

    Args:
        all_res_pd: DataFrame as returned by ``load_data_to_pd``.
        data_cate: dataset names (values of the ``data`` column) to plot.
        data_cate_name: family name. "femnist_all" plots against ``s``;
            "cifar10_all" plots against ``alpha`` on a log scale; any other
            name plots against ``data``. It also forms the file name
            ``generalization_all_{data_cate_name}_{i}.pdf``, where ``i``
            indexes ``column_names_generalization``.

    Global-Train, Isolated, FedOpt and FedOpt-FT are left out.
    """
    import seaborn as sns
    from matplotlib import pyplot as plt
    print(all_res_pd.columns.tolist())
    filter_out_methods = ["Global-Train", "Isolated", "FedOpt", "FedOpt-FT"]
    plot_data = all_res_pd.loc[all_res_pd["data"].isin(data_cate)
                               & ~all_res_pd["method"].isin(filter_out_methods)]
    hue_order = [m for m in expected_method_names if m not in filter_out_methods]
    x = {"femnist_all": "s", "cifar10_all": "alpha"}.get(data_cate_name, "data")
    sns.set()
    for i, metric in enumerate(column_names_generalization):
        fig, axes = plt.subplots(1, 1, figsize=(2, 3))
        ax = sns.lineplot(
            ax=axes,
            data=plot_data,
            x=x,
            y=metric,
            hue="method",
            style="method",
            markers=True,
            dashes=True,
            hue_order=hue_order,
            sort=True,
        )
        ax.set(ylabel=column_names_generalization_for_plot[i])
        ax.invert_xaxis()
        if data_cate_name == "cifar10_all":
            ax.set_xscale('log')
        plt.legend(bbox_to_anchor=(1, 1), loc=2, ncol=2, borderaxespad=0.)
        plt.tight_layout()
        plt.savefig(f"generalization_all_{data_cate_name}_{i}.pdf",
                    bbox_inches='tight',
                    pad_inches=0)
        plt.show()
        # Release the figure; otherwise every call leaves open figures behind.
        plt.close(fig)
def plot_tradeoff(all_res_pd, data_cate, data_cate_name, metric_a, metric_b,
                  fig_time):
    """Save and show a scatter plot of ``metric_b`` against ``metric_a``.

    Args:
        all_res_pd: DataFrame as returned by ``load_data_to_pd``.
        data_cate: dataset names (values of the ``data`` column) to plot.
        data_cate_name: used in the file name; "cifar10_all" also switches
            the x axis to log scale.
        metric_a: x-axis column ("total_flops" always uses a log x axis).
        metric_b: y-axis column.
        fig_time: file-name prefix; the plot is saved as
            ``{fig_time}_{data_cate_name}.pdf``.

    Axis labels come from ``column_name_for_plot``, falling back to the raw
    column name. Global-Train, Isolated, FedOpt and FedOpt-FT are left out.
    """
    import seaborn as sns
    from matplotlib import pyplot as plt
    print(all_res_pd.columns.tolist())
    filter_out_methods = ["Global-Train", "Isolated", "FedOpt", "FedOpt-FT"]
    plot_data = all_res_pd.loc[all_res_pd["data"].isin(data_cate)
                               & ~all_res_pd["method"].isin(filter_out_methods)]
    sns.set()
    fig, axes = plt.subplots(1, 1, figsize=(2, 3))
    ax = sns.scatterplot(ax=axes,
                         data=plot_data,
                         x=metric_a,
                         y=metric_b,
                         hue="method",
                         style="method",
                         markers=True,
                         hue_order=[
                             m for m in expected_method_names
                             if m not in filter_out_methods
                         ],
                         s=100)
    ax.set(xlabel=column_name_for_plot.get(metric_a, metric_a),
           ylabel=column_name_for_plot.get(metric_b, metric_b))
    if metric_a == "total_flops" or data_cate_name == "cifar10_all":
        ax.set_xscale('log')
    plt.legend(bbox_to_anchor=(1, 1), loc=2, ncol=2, borderaxespad=0.)
    plt.tight_layout()
    plt.savefig(f"{fig_time}_{data_cate_name}.pdf",
                bbox_inches='tight',
                pad_inches=0)
    plt.show()
    # Release the figure; repeated calls would otherwise accumulate figures.
    plt.close(fig)
if __name__ == "__main__":
    # Runs with seed "1" or tag "repeat" are not reported as missing.
    load_best_repeat_res(["1", "repeat"])
    avg_res_of_seeds()
    for filters_table in (filters_each_line_main_table,
                          filters_each_line_femnist_all_s,
                          filters_each_line_all_cifar10,
                          filters_each_line_all_nlp,
                          filters_each_line_all_graph):
        print_paper_table_from_repeat(filters_table)
    all_res_pd = load_data_to_pd(use_repeat_res=False)
    all_res_pd_repeat = load_data_to_pd(use_repeat_res=True)

    # NOTE(review): the two plotting helpers below are defined but not called
    # in this file; presumably they are invoked manually (e.g. from an
    # interactive session after running the script) -- confirm.
    def plot_line_figs():
        """Plot the generalization lines for the FEMNIST and CIFAR-10 families."""
        for filters_table, cate_name in (
                (filters_each_line_femnist_all_s, "femnist_all"),
                (filters_each_line_all_cifar10, "cifar10_all")):
            plot_generalization_lines(all_res_pd,
                                      list(filters_table.keys()),
                                      data_cate_name=cate_name)

    def plot_trade_off_figs(filters_each_line_main_table):
        """Plot communication/flops/round vs. accuracy for each dataset of the table."""
        for fig_time, metric_a in (
                ("com-acc", "communication_bytes"),
                ("flops-acc", "total_flops"),
                ("round-acc", "sys_avg/global_convergence_round")):
            for data_name in list(filters_each_line_main_table.keys()):
                plot_tradeoff(
                    all_res_pd_repeat, [data_name],
                    data_cate_name=data_name,
                    metric_a=metric_a,
                    metric_b="best_client_summarized_weighted_avg/test_acc",
                    fig_time=fig_time)