# Analysis script: collects best-repeat experiment results from wandb,
# averages them over seeds, prints LaTeX table rows and draws plots.
import copy
import json
import os
from collections import OrderedDict

import numpy as np
import pandas as pd
import wandb
import yaml
|
|
|
|
# Shared wandb API client; used by load_best_repeat_res to query runs.
api = wandb.Api()

# wandb "entity/project" identifier.
# NOTE(review): not referenced elsewhere in this file -- confirm it is still needed.
name_project = "daoyuan/pFL-bench"
|
|
|
|
# Datasets shown in the main paper table: {dataset_name: run filter}.
# Insertion order is the column order of the printed tables.
#
# Entries that are currently disabled:
#   ("all", None)
#   ("FEMNIST-all", {"$and": [{"config.data.type": "femnist"}]})
#   ("cifar10-alpha05", {"$and": [
#       {"config.data.type": "CIFAR10@torchvision"},
#       {"config.data.splitter_args": [{"alpha": 0.5}]}]})
filters_each_line_main_table = OrderedDict([
    ("FEMNIST-s02", {
        "$and": [
            {"config.data.type": "femnist"},
            {"config.federate.sample_client_rate": 0.2},
            {"state": "finished"},
        ]
    }),
    ("sst2", {
        "$and": [{"config.data.type": "sst2@huggingface_datasets"}]
    }),
    ("pubmed", {
        "$and": [{"config.data.type": "pubmed"}]
    }),
])
|
|
|
|
# CIFAR-10 variants, one per splitter `alpha`: {dataset_name: run filter}.
filters_each_line_all_cifar10 = OrderedDict(
    (f"cifar10-alpha{_suffix}", {
        "$and": [
            {"config.data.type": "CIFAR10@torchvision"},
            {"config.data.splitter_args": [{"alpha": _alpha}]},
        ]
    }) for _suffix, _alpha in (("5", 5), ("05", 0.5), ("01", 0.1)))
|
|
|
|
# FEMNIST variants, one per client sample rate: {dataset_name: run filter}.
# Only runs whose state is "finished" are selected.
filters_each_line_femnist_all_s = OrderedDict(
    (f"FEMNIST-s{_suffix}", {
        "$and": [
            {"config.data.type": "femnist"},
            {"config.federate.sample_client_rate": _rate},
            {"state": "finished"},
        ]
    }) for _suffix, _rate in (("02", 0.2), ("01", 0.1), ("005", 0.05)))
|
|
|
|
# Graph datasets: {dataset_name: run filter}; the data type equals the name.
filters_each_line_all_graph = OrderedDict(
    (_graph_name, {
        "$and": [{"config.data.type": _graph_name}]
    }) for _graph_name in ("pubmed", "cora", "citeseer"))
|
|
|
|
# NLP datasets: {dataset_name: run filter}.
filters_each_line_all_nlp = OrderedDict(
    (_nlp_name, {
        "$and": [{"config.data.type": f"{_nlp_name}@huggingface_datasets"}]
    }) for _nlp_name in ("cola", "sst2"))
|
|
|
|
sweep_name_2_id = {}

# Metric keys, grouped by the table they feed.
column_names_generalization = [
    "best_client_summarized_weighted_avg/test_acc",
    "best_unseen_client_summarized_weighted_avg_unseen/test_acc",
    "participation_gap",
]
column_names_fair = [
    "best_client_summarized_avg/test_acc",
    "best_client_summarized_fairness/test_acc_std",
    "best_client_summarized_fairness/test_acc_bottom_decile",
]
column_names_efficiency = [
    "sys_avg/total_flops",
    "sys_avg/total_upload_bytes",
    "sys_avg/total_download_bytes",
    "sys_avg/global_convergence_round",
    # "sys_avg/local_convergence_round"
]

# Axis labels used by the plotting helpers.
column_names_generalization_for_plot = [
    "Acc (Parti.)",
    "Acc (Un-parti.)",
    "Generalization Gap",
]
column_name_for_plot = {
    "best_client_summarized_weighted_avg/test_acc": "Acc (Parti.)",
    "total_flops": "Total Flops",
    "communication_bytes": "Communication Bytes",
    "sys_avg/global_convergence_round": "Convergence Round",
}

# (key, display name) pairs; the order defines the row order of the tables.
sorted_method_name_pair = [
    ("global-train", "Global-Train"),
    ("isolated-train", "Isolated"),
    ("fedavg", "FedAvg"),
    ("fedavg-ft", "FedAvg-FT"),
    ("fedopt", "FedOpt"),
    ("fedopt-ft", "FedOpt-FT"),
    ("pfedme", "pFedMe"),
    ("ft-pfedme", "pFedMe-FT"),
    ("fedbn", "FedBN"),
    ("fedbn-ft", "FedBN-FT"),
    ("fedbn-fedopt", "FedBN-FedOPT"),
    ("fedbn-fedopt-ft", "FedBN-FedOPT-FT"),
    ("ditto", "Ditto"),
    ("ditto-ft", "Ditto-FT"),
    ("ditto-fedbn", "Ditto-FedBN"),
    ("ditto-fedbn-ft", "Ditto-FedBN-FT"),
    ("ditto-fedbn-fedopt", "Ditto-FedBN-FedOpt"),
    ("ditto-fedbn-fedopt-ft", "Ditto-FedBN-FedOpt-FT"),
    ("fedem", "FedEM"),
    ("fedem-ft", "FedEM-FT"),
    ("fedbn-fedem", "FedEM-FedBN"),
    ("fedbn-fedem-ft", "FedEM-FedBN-FT"),
    ("fedbn-fedem-fedopt", "FedEM-FedBN-FedOPT"),
    ("fedbn-fedem-fedopt-ft", "FedEM-FedBN-FedOPT-FT"),
]
sorted_keys = OrderedDict(sorted_method_name_pair)
expected_keys = set(sorted_keys)
expected_method_names = list(sorted_keys.values())

expected_datasets_name = [
    "cola",
    "sst2",
    "pubmed",
    "cora",
    "citeseer",
    "cifar10-alpha5",
    "cifar10-alpha05",
    "cifar10-alpha01",
    "FEMNIST-s02",
    "FEMNIST-s01",
    "FEMNIST-s005",
]
expected_seed_set = ["1", "2", "3"]
# Filled by the module-level loop that follows these constants.
expected_expname_tag = set()

# Methods ranked separately (red/blue colouring) by
# highlight_tex_res_in_table.
original_method_names = [
    "Global-Train",
    "Isolated",
    "FedAvg",
    "pFedMe",
    "FedBN",
    "Ditto",
    "FedEM",
]
|
|
|
|
# Register every expected "<method>_<dataset>_seed<k>" tag, plus one
# "<method>_<dataset>_repeat" tag per (method, dataset) pair.
for method_name in expected_method_names:
    for dataset_name in expected_datasets_name:
        _tag_prefix = f"{method_name}_{dataset_name}"
        for seed in expected_seed_set:
            expected_expname_tag.add(f"{_tag_prefix}_seed{seed}")
        expected_expname_tag.add(f"{_tag_prefix}_repeat")
|
|
|
|
from collections import defaultdict

# seed tag -> list of shell commands that would re-run a missing experiment.
all_missing_scripts = defaultdict(list)

# expname_tag -> {metric: value}.  Per-seed tags start as "-" (missing);
# "repeat" tags start as an empty list that collects the per-seed values.
all_res_structed = defaultdict(dict)
for expname_tag in expected_expname_tag:
    for metric in (column_names_generalization + column_names_efficiency +
                   column_names_fair):
        all_res_structed[expname_tag][metric] = (
            [] if "repeat" in expname_tag else "-")
|
|
|
|
|
|
def load_best_repeat_res(filter_seed_set=None):
    """Fetch results for every expected experiment tag from wandb.

    For each tag in ``expected_expname_tag`` the project
    "pfl-bench-best-repeat" is queried for runs whose
    ``config.expname_tag`` equals the tag.  Metrics of finished runs are
    written into the module-level ``all_res_structed``.  Tags without exactly
    one finished run are reported, and for tags with none a re-run command is
    appended to the module-level ``all_missing_scripts`` (printed at the end).

    Args:
        filter_seed_set: iterable of strings; a tag whose seed component
            (e.g. "seed1" or "repeat") contains one of them is excluded from
            the missing/duplicate report.  ``None`` excludes nothing.
    """
    # The original iterated the default None directly and crashed.
    filter_seed_set = [] if filter_seed_set is None else list(filter_seed_set)
    all_metrics = (column_names_generalization + column_names_efficiency +
                   column_names_fair)
    acc_key = "best_client_summarized_weighted_avg/test_acc"
    unseen_acc_key = \
        "best_unseen_client_summarized_weighted_avg_unseen/test_acc"

    for expname_tag in expected_expname_tag:
        run_filter = {"$and": [{"config.expname_tag": expname_tag}]}
        filtered_runs = api.runs("pfl-bench-best-repeat", filters=run_filter)
        method, dataname, seed = expname_tag.split("_")
        finished_run_cnt = 0
        for run in filtered_runs:
            if run.state != "finished":
                print(f"run {run} is not fished")
                continue
            finished_run_cnt += 1
            for metric in all_metrics:
                try:
                    if method in ["Isolated", "Global-Train"]:
                        skip_generalize = ("unseen" in metric
                                           or metric == "participation_gap")
                        skip_global_fairness = (method == "Global-Train"
                                                and "fairness" in metric)
                        if skip_generalize or skip_global_fairness:
                            all_res_structed[expname_tag][metric] = "-"
                            continue

                    if metric == "participation_gap":
                        # Derived metric; relies on both accuracies being
                        # stored earlier in this metric loop.
                        all_res_structed[expname_tag][metric] = (
                            all_res_structed[expname_tag][unseen_acc_key] -
                            all_res_structed[expname_tag][acc_key])
                    else:
                        all_res_structed[expname_tag][metric] = \
                            run.summary[metric]
                except (KeyError, TypeError):
                    # KeyError: metric missing from the run summary.
                    # TypeError: an accuracy is still the "-" placeholder,
                    # so the gap cannot be computed.
                    print("Something wrong")

        # Match the filter against the seed component only: matching the whole
        # tag let "1" hit "cifar10", "alpha01" or "s01".  A separate loop
        # variable also keeps `seed` (the tag's own seed) intact for the
        # script generation below.
        print_missing = not any(
            wanted in seed for wanted in filter_seed_set)
        if not print_missing:
            continue

        if finished_run_cnt == 0:
            print(f"Missing run {expname_tag})")
            yaml_name = f"{method}_{dataname}.yaml"
            if "Global" in method:
                yaml_name = f"'{yaml_name}'"
                expname_tag_new = expname_tag.replace("Global Train",
                                                      "Global-Train")
            else:
                expname_tag_new = expname_tag
            seed_num = seed.replace("seed", "")
            all_missing_scripts[seed].append(
                f"python federatedscope/main.py --cfg scripts/personalization_exp_scripts/pfl_bench/yaml_best_runs/{yaml_name} seed {seed_num} expname_tag {expname_tag_new} wandb.name_project pfl-bench-best-repeat"
            )
        elif finished_run_cnt != 1:
            print(f"run_cnt = {finished_run_cnt} for the exp {expname_tag}")

    for missing_seed, scripts in all_missing_scripts.items():
        print(
            f"+================= All MISSING SCRIPTS, seed={missing_seed} =====================+, cnt={len(scripts)}"
        )
        for script in scripts:
            print(script)
        print()
|
|
|
|
|
|
def bytes_to_unit_size(size_bytes):
    """Render a number with a 1024-based unit suffix, e.g. 1536 -> "1.5K".

    Args:
        size_bytes: int or float quantity.

    Returns:
        "0" for zero; otherwise the value divided by the largest fitting
        power of 1024, rounded to two decimals, followed by one of
        "", "K", "M", "G", "T", "P", "E", "Z", "Y".
    """
    import math

    if size_bytes == 0:
        return "0"
    size_name = ("", "K", "M", "G", "T", "P", "E", "Z", "Y")
    # Find the exponent by repeated division instead of floor(log(...)):
    # the log form gave -1 (suffix "Y") for |x| < 1, raised ValueError for
    # negatives and IndexError beyond the last suffix.
    magnitude = abs(size_bytes)
    i = 0
    while magnitude >= 1024 and i < len(size_name) - 1:
        magnitude /= 1024
        i += 1
    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return f"{s}{size_name[i]}"
|
|
|
|
|
|
def unit_size_to_bytes(size_str):
    """Convert a string such as "1.5K" back to a plain float (1536.0).

    Args:
        size_str: any value.  Non-strings are returned untouched.

    Returns:
        A float when ``size_str`` is a numeric string, optionally ending in
        one of the 1024-based suffixes K, M, G, T, P, E, Z, Y.  Any string
        that cannot be parsed (e.g. "-", "" or a method name) is returned
        unchanged.
    """
    if not isinstance(size_str, str):
        return size_str

    import math

    size_name = ("", "K", "M", "G", "T", "P", "E", "Z", "Y")
    try:
        last_unit = size_str[-1]  # IndexError on ""
        if last_unit not in size_name:
            return float(size_str)
        exponent = size_name.index(last_unit)
        return float(size_str[:-1]) * math.pow(1024, exponent)
    except (IndexError, ValueError):
        # Only the two failures that can occur here are caught; the
        # original bare `except:` also swallowed e.g. KeyboardInterrupt.
        return size_str
|
|
|
|
|
|
def avg_res_of_seeds():
    """Average the per-seed results into the matching "repeat" entries.

    Works in place on the module-level ``all_res_structed``.  First every
    per-seed value is appended to the list stored under
    "<method>_<dataset>_repeat"; then each such list is replaced by the mean
    of its non-"-" values ("-" when none is left).  Metrics whose name
    contains "flops" or "bytes" are averaged as plain numbers and stored
    again as unit-suffixed strings.
    """
    metrics = (column_names_generalization + column_names_efficiency +
               column_names_fair)

    # Pass 1: gather the per-seed values under the "repeat" tag.
    for tag in expected_expname_tag:
        if "repeat" in tag:
            continue
        is_baseline = "Global" in tag or "Isolated" in tag
        for metric in metrics:
            value = all_res_structed[tag][metric]
            if value == "-" and not is_baseline:
                print(f"missing {tag} for metric {metric}")
            method, dataname, _seed = tag.split("_")
            all_res_structed[f"{method}_{dataname}_repeat"][metric].append(
                value)

    # Pass 2: collapse each gathered list into its mean.
    for tag in expected_expname_tag:
        if "repeat" not in tag:
            continue
        for metric in metrics:
            numeric = [
                unit_size_to_bytes(v) for v in all_res_structed[tag][metric]
                if v != "-"
            ]
            if len(numeric) == 0:
                all_res_structed[tag][metric] = "-"
                continue
            mean = sum(numeric) / len(numeric)
            if "flops" in metric or "bytes" in metric:
                mean = bytes_to_unit_size(mean)
            all_res_structed[tag][metric] = mean
|
|
|
|
|
|
def highlight_tex_res_in_table(res_to_print_matrix_raw,
                               rank_order,
                               need_scale=False,
                               filter_out=None,
                               convergence_case=False):
    """Add LaTeX highlighting to the best and second-best cell per column.

    Each row is ``[method_name, cell, cell, ...]`` with string cells; "-"
    marks a missing result.  Per column, the rank-1 cell among all rows is
    wrapped in ``\\textbf{}`` and the rank-2 cell in ``\\underline{}``.
    Rows whose method is in ``original_method_names`` are additionally ranked
    among themselves and wrapped in ``\\color{red}{}`` (rank 1) or
    ``\\color{blue}{}`` (rank 2).  Ranks are dense, so ties share a rank.

    Args:
        res_to_print_matrix_raw: list of rows as described above.  Row lists
            are modified in place.
        rank_order: one "+" or "-" per data column; "+" means larger is
            better.
        need_scale: run every cell through ``unit_size_to_bytes`` before
            ranking (the printed text keeps its original form).
        filter_out: method names whose rows are dropped from the output.
        convergence_case: treat cells equal to the string "0" as the worst
            value.

    Returns:
        The (possibly filtered) list of rows with markup applied.
    """
    if filter_out is None:
        res_to_print_matrix = res_to_print_matrix_raw
    else:
        # Drop unwanted methods, e.g. Global-Train and Isolated.
        res_to_print_matrix = [
            row for row in res_to_print_matrix_raw if row[0] not in filter_out
        ]

    res_np = np.array(res_to_print_matrix)
    row_len, _col_len = res_np.shape

    if need_scale:
        res_np = np.vectorize(unit_size_to_bytes)(res_np)

    # Rows that take part in the second, restricted ranking.
    selected_method_idx = [
        i for i in range(row_len) if res_np[i, 0] in original_method_names
    ]
    raw_i_to_selected_i = {
        raw_i: pos
        for pos, raw_i in enumerate(selected_method_idx)
    }

    def _wrap(row, cell_idx, prefix):
        row[cell_idx] = prefix + row[cell_idx] + "}"

    # Render column by column.
    for col_i, col in enumerate(res_np[:, 1:].T):
        larger_is_better = rank_order[col_i] == "+"
        # Missing results get a sentinel that always ranks last.
        if larger_is_better:
            col = np.where(col == "-", -999999999, col)
        else:
            col = np.where(col == "-", 9999999999, col)
        if convergence_case:
            # NOTE(review): print_paper_table_from_repeat formats rounds with
            # two decimals ("0.00"), which does not equal "0" -- confirm this
            # branch still matches what callers pass.
            col = np.where(col == "0", 9999999999, col)
        col = col.astype("float")
        if larger_is_better:
            # Negate so that rank 1 is always the best value.
            col = -col

        ind_all_method = pd.DataFrame(col).rank(
            method='dense').astype(int)[0].values.tolist()
        ind_partial_method = pd.DataFrame(col[selected_method_idx]).rank(
            method='dense').astype(int)[0].values.tolist()

        cell_idx = col_i + 1
        for raw_i in range(row_len):
            row = res_to_print_matrix[raw_i]
            overall_rank = ind_all_method[raw_i]
            if overall_rank == 1:
                _wrap(row, cell_idx, "\\textbf{")
            elif overall_rank == 2:
                _wrap(row, cell_idx, "\\underline{")

            if raw_i not in raw_i_to_selected_i:
                continue
            partial_rank = ind_partial_method[raw_i_to_selected_i[raw_i]]
            if partial_rank == 1:
                _wrap(row, cell_idx, "\\color{red}{")
            elif partial_rank == 2:
                _wrap(row, cell_idx, "\\color{blue}{")

    return res_to_print_matrix
|
|
|
|
|
|
def _format_percent(v):
    """Format a fraction as a two-decimal percentage; pass "-" through."""
    return "{:.2f}".format(v * 100) if v != "-" else v


def _print_tex_rows(res_to_print_matrix):
    """Print each row as '&'-joined cells terminated by a LaTeX line break."""
    for res_to_print in res_to_print_matrix:
        print("&".join(res_to_print) + "\\\\")


def print_paper_table_from_repeat(filters_each_line_table):
    """Print LaTeX rows of the paper tables for the given datasets.

    Reads the seed-averaged "<method>_<dataset>_repeat" entries of the
    module-level ``all_res_structed`` and prints four tables: generalization,
    fairness, (flops, upload bytes, accuracy) and (convergence round,
    accuracy).  Each table is passed through ``highlight_tex_res_in_table``.

    Args:
        filters_each_line_table: mapping whose keys are dataset names; only
            the keys (and their order) are used.
    """
    dataset_names = list(filters_each_line_table.keys())
    # Derived from the inputs rather than hard-coded (was
    # `2 if "cola" in keys else 3` with strides 4 and 3), so tables with any
    # number of datasets line up with `rank_order`.
    dataset_num = len(dataset_names)
    n_gen = len(column_names_generalization)
    n_fair = len(column_names_fair)
    n_eff = len(column_names_efficiency)
    flops_off = column_names_efficiency.index("sys_avg/total_flops")
    upload_off = column_names_efficiency.index("sys_avg/total_upload_bytes")
    round_off = column_names_efficiency.index(
        "sys_avg/global_convergence_round")
    header_suffix = ",".join(dataset_names)

    res_of_each_line_generalization = OrderedDict()
    res_of_each_line_fair = OrderedDict()
    res_of_each_line_efficiency = OrderedDict()

    for key in expected_method_names:
        res_of_each_line_generalization[key] = []
        res_of_each_line_fair[key] = []
        res_of_each_line_efficiency[key] = []
        for dataset_name in dataset_names:
            res_of_tag = all_res_structed[f"{key}_{dataset_name}_repeat"]
            for metric in column_names_generalization:
                res_of_each_line_generalization[key].append(
                    res_of_tag[metric])
            for metric in column_names_fair:
                res_of_each_line_fair[key].append(res_of_tag[metric])
            for metric in column_names_efficiency:
                res = res_of_tag[metric]
                # A missing result is the string "-", which "{:.2f}" cannot
                # format; leave it as is.
                if "round" in metric and res != "-":
                    res = "{:.2f}".format(res)
                res_of_each_line_efficiency[key].append(res)

    # ---- Generalization: Acc, Unseen-Acc, Delta -------------------------
    print("\n=============res_of_each_line [Generalization]===============" +
          header_suffix)
    res_to_print_matrix = [
        [key] +
        [_format_percent(v) for v in res_of_each_line_generalization[key]]
        for key in expected_method_names
    ]
    # "+" indicates the larger, the better
    rank_order = ["+", "+", "+"] * dataset_num
    _print_tex_rows(
        highlight_tex_res_in_table(res_to_print_matrix,
                                   rank_order=rank_order))

    # ---- Fairness -------------------------------------------------------
    print("\n=============res_of_each_line [Fairness]===============" +
          header_suffix)
    res_to_print_matrix = [
        [key] + [_format_percent(v) for v in res_of_each_line_fair[key]]
        for key in expected_method_names
    ]
    rank_order = ["+", "-", "+"] * dataset_num
    _print_tex_rows(
        highlight_tex_res_in_table(res_to_print_matrix,
                                   rank_order=rank_order,
                                   filter_out=["Global-Train"]))

    # ---- Flops / upload bytes / accuracy --------------------------------
    print(
        "\n=============res_of_each_line [flops, communication, acc]==============="
        + header_suffix)
    res_to_print_matrix = []
    for key in expected_method_names:
        eff = res_of_each_line_efficiency[key]
        gen = res_of_each_line_generalization[key]
        res_to_print = [key]
        for i in range(dataset_num):
            res_to_print.append(str(eff[i * n_eff + flops_off]))
            res_to_print.append(str(eff[i * n_eff + upload_off]))
            # First generalization metric of dataset i.
            res_to_print.append(str(_format_percent(gen[i * n_gen])))
        res_to_print_matrix.append(res_to_print)
    rank_order = ["-", "-", "+"] * dataset_num
    _print_tex_rows(
        highlight_tex_res_in_table(res_to_print_matrix,
                                   rank_order=rank_order,
                                   need_scale=True,
                                   filter_out=["Global-Train", "Isolated"]))

    # ---- Convergence round / accuracy -----------------------------------
    print(
        "\n=============res_of_each_line [converge_round, acc]==============="
        + header_suffix)
    res_to_print_matrix = []
    for key in expected_method_names:
        eff = res_of_each_line_efficiency[key]
        fair = res_of_each_line_fair[key]
        res_to_print = [key]
        for i in range(dataset_num):
            res_to_print.append(str(eff[i * n_eff + round_off]))
            # First fairness metric of dataset i.
            res_to_print.append(str(_format_percent(fair[i * n_fair])))
        res_to_print_matrix.append(res_to_print)
    rank_order = ["-", "+"] * dataset_num
    _print_tex_rows(
        highlight_tex_res_in_table(res_to_print_matrix,
                                   rank_order=rank_order,
                                   filter_out=["Global-Train", "Isolated"],
                                   convergence_case=True))
|
|
|
|
|
|
import json

# Seed the per-seed entries of `all_res_structed` from a local JSON dump.
# This runs when the module is loaded, so 'best_res_all_metric.json' must
# exist in the working directory.
with open('best_res_all_metric.json', 'r') as fp:
    all_res_structed_load = json.load(fp)

for expname_tag in expected_expname_tag:
    if "repeat" not in expname_tag:
        for metric in (column_names_generalization +
                       column_names_efficiency + column_names_fair):
            all_res_structed[expname_tag][metric] = \
                all_res_structed_load[expname_tag][metric]

# add all res to a df
import pandas as pd
|
|
|
|
|
|
def load_data_to_pd(use_repeat_res=False):
    """Build a DataFrame with one row per experiment tag.

    Args:
        use_repeat_res: when true only the "repeat" (seed-averaged) tags are
            used; otherwise only the per-seed tags.

    Returns:
        DataFrame with columns "method", "data", "seed", all generalization,
        fairness and efficiency metrics, then "s" (FEMNIST sample rate or
        "-"), "alpha" (CIFAR-10 alpha or "-"), "communication_bytes"
        (upload + download) and "total_flops".
    """
    metrics = (column_names_generalization + column_names_fair +
               column_names_efficiency)
    records = []
    for expname_tag in expected_expname_tag:
        is_repeat_tag = "repeat" in expname_tag
        if is_repeat_tag != bool(use_repeat_res):
            continue

        row = expname_tag.split("_")  # method, data, seed
        row.extend(all_res_structed[expname_tag][m] for m in metrics)

        data_name = row[1]
        # Recover the numeric setting encoded in the dataset name,
        # e.g. "FEMNIST-s005" -> 0.05, "cifar10-alpha5" -> 5.0.
        s = "-"
        if "FEMNIST-s0" in data_name:
            s = float(data_name.replace("FEMNIST-s0", "0."))
        alpha = "-"
        if "cifar10-alpha0" in data_name:
            alpha = float(data_name.replace("cifar10-alpha0", "0."))
        elif "cifar10-alpha" in data_name:
            alpha = float(data_name.replace("cifar10-alpha", ""))
        row += [s, alpha]

        # With s and alpha appended, positions -6, -5 and -4 are the first
        # three efficiency metrics: flops, upload bytes, download bytes.
        flops, upload, download = row[-6:-3]
        total_com_bytes = unit_size_to_bytes(upload) + unit_size_to_bytes(
            download)
        total_flops = unit_size_to_bytes(flops)
        row += [total_com_bytes, total_flops]
        records.append(row)

    column_labels = (["method", "data", "seed"] +
                     column_names_generalization + column_names_fair +
                     column_names_efficiency +
                     ["s", "alpha", "communication_bytes", "total_flops"])
    return pd.DataFrame.from_records(records, columns=column_labels)
|
|
|
|
|
|
def plot_generalization_lines(all_res_pd, data_cate, data_cate_name):
    """Draw one line plot per generalization metric and save each as a PDF.

    Args:
        all_res_pd: DataFrame as returned by ``load_data_to_pd``.
        data_cate: dataset names (values of the "data" column) to include.
        data_cate_name: label used in the output file name; "femnist_all"
            puts column "s" on the x axis, "cifar10_all" uses "alpha" with a
            log scale, anything else uses "data".

    Writes "generalization_all_<data_cate_name>_<i>.pdf" for each metric
    index ``i`` and shows each figure.
    """
    import seaborn as sns
    from matplotlib import pyplot as plt

    print(all_res_pd.columns.tolist())

    filter_out_methods = ["Global-Train", "Isolated", "FedOpt", "FedOpt-FT"]
    plot_data = all_res_pd.loc[all_res_pd["data"].isin(data_cate)]
    plot_data = plot_data.loc[~plot_data["method"].isin(filter_out_methods)]
    hue_order = [
        m for m in expected_method_names if m not in filter_out_methods
    ]

    x = "data"
    if data_cate_name == "femnist_all":
        x = "s"
    if data_cate_name == "cifar10_all":
        x = "alpha"

    # The original also created an unused 1x3 figure up front and never
    # closed any figure, so figures piled up across calls.
    for i, metric in enumerate(column_names_generalization):
        sns.set()
        fig, axes = plt.subplots(1, 1, figsize=(2, 3))

        ax = sns.lineplot(
            ax=axes,
            data=plot_data,
            x=x,
            y=metric,
            hue="method",
            style="method",
            markers=True,
            dashes=True,
            hue_order=hue_order,
            sort=True,
        )
        ax.set(ylabel=column_names_generalization_for_plot[i])
        plt.gca().invert_xaxis()

        if data_cate_name == "cifar10_all":
            ax.set_xscale('log')

        plt.legend(bbox_to_anchor=(1, 1), loc=2, ncol=2, borderaxespad=0.)
        plt.tight_layout()
        plt.savefig(f"generalization_all_{data_cate_name}_{i}.pdf",
                    bbox_inches='tight',
                    pad_inches=0)

        plt.show()
        # Release the figure once it has been saved and shown.
        plt.close(fig)
|
|
|
|
|
|
def plot_tradeoff(all_res_pd, data_cate, data_cate_name, metric_a, metric_b,
                  fig_time):
    """Scatter-plot ``metric_b`` against ``metric_a`` per method; save a PDF.

    Args:
        all_res_pd: DataFrame as returned by ``load_data_to_pd``.
        data_cate: dataset names (values of the "data" column) to include.
        data_cate_name: label used in the output file name.
        metric_a: column for the x axis; must be a key of
            ``column_name_for_plot``.
        metric_b: column for the y axis; must be a key of
            ``column_name_for_plot``.
        fig_time: prefix of the output file name.

    Writes "<fig_time>_<data_cate_name>.pdf" and shows the figure.  The x
    axis is logarithmic when ``metric_a`` is "total_flops" or
    ``data_cate_name`` is "cifar10_all".
    """
    import seaborn as sns
    from matplotlib import pyplot as plt

    print(all_res_pd.columns.tolist())

    filter_out_methods = ["Global-Train", "Isolated", "FedOpt", "FedOpt-FT"]
    plot_data = all_res_pd.loc[all_res_pd["data"].isin(data_cate)]
    plot_data = plot_data.loc[~plot_data["method"].isin(filter_out_methods)]

    sns.set()
    fig, axes = plt.subplots(1, 1, figsize=(2, 3))

    ax = sns.scatterplot(ax=axes,
                         data=plot_data,
                         x=metric_a,
                         y=metric_b,
                         hue="method",
                         style="method",
                         markers=True,
                         hue_order=[
                             m for m in expected_method_names
                             if m not in filter_out_methods
                         ],
                         s=100)
    ax.set(xlabel=column_name_for_plot[metric_a],
           ylabel=column_name_for_plot[metric_b])
    if metric_a == "total_flops" or data_cate_name == "cifar10_all":
        ax.set_xscale('log')

    plt.legend(bbox_to_anchor=(1, 1), loc=2, ncol=2, borderaxespad=0.)
    plt.tight_layout()
    plt.savefig(f"{fig_time}_{data_cate_name}.pdf",
                bbox_inches='tight',
                pad_inches=0)

    plt.show()
    # Release the figure; this function is called once per dataset and
    # metric, and open figures would otherwise accumulate.
    plt.close(fig)
|
|
|
|
|
|
if __name__ == "__main__":
    # Pull results from wandb; tags whose seed part matches "1" or "repeat"
    # are excluded from the missing-run report.
    load_best_repeat_res(["1", "repeat"])
    # Collapse per-seed results into the "repeat" entries the tables read.
    avg_res_of_seeds()

    # One set of tables per dataset group.
    print_paper_table_from_repeat(filters_each_line_main_table)
    print_paper_table_from_repeat(filters_each_line_femnist_all_s)
    print_paper_table_from_repeat(filters_each_line_all_cifar10)
    print_paper_table_from_repeat(filters_each_line_all_nlp)
    print_paper_table_from_repeat(filters_each_line_all_graph)
|
|
|
|
# DataFrames read as globals by plot_line_figs / plot_trade_off_figs.
# NOTE(review): the original indentation was lost, so these may have been
# inside the `if __name__ == "__main__":` block above -- confirm.
all_res_pd = load_data_to_pd(use_repeat_res=False)
all_res_pd_repeat = load_data_to_pd(use_repeat_res=True)
|
|
|
|
|
|
def plot_line_figs():
    """Draw the generalization line plots for the FEMNIST and CIFAR-10 groups.

    Reads the module-level ``all_res_pd`` (per-seed results).
    """
    groups = (
        (filters_each_line_femnist_all_s, "femnist_all"),
        (filters_each_line_all_cifar10, "cifar10_all"),
    )
    for dataset_filters, cate_name in groups:
        plot_generalization_lines(all_res_pd,
                                  list(dataset_filters.keys()),
                                  data_cate_name=cate_name)
|
|
|
|
|
|
def plot_trade_off_figs(filters_each_line_main_table):
    """Draw accuracy trade-off scatter plots for every dataset in the table.

    For each of communication bytes, total flops and convergence round, one
    figure per dataset is produced via ``plot_tradeoff`` from the
    module-level ``all_res_pd_repeat`` (seed-averaged results).

    Args:
        filters_each_line_main_table: mapping whose keys are dataset names.
    """
    acc_metric = "best_client_summarized_weighted_avg/test_acc"
    # (x-axis metric, output file prefix), in plotting order.
    trade_offs = (
        ("communication_bytes", "com-acc"),
        ("total_flops", "flops-acc"),
        ("sys_avg/global_convergence_round", "round-acc"),
    )
    for x_metric, file_prefix in trade_offs:
        for data_name in list(filters_each_line_main_table.keys()):
            plot_tradeoff(all_res_pd_repeat, [data_name],
                          data_cate_name=data_name,
                          metric_a=x_metric,
                          metric_b=acc_metric,
                          fig_time=file_prefix)
|