import copy import json import os import numpy as np import wandb from collections import OrderedDict import yaml api = wandb.Api() name_project = "daoyuan/pFL-bench" filters_each_line_main_table = OrderedDict( # {dataset_name: filter} [ # ("all", # None, # ), # ("FEMNIST-all", # {"$and": # [ # {"config.data.type": "femnist"}, # ] # } # ), ("FEMNIST-s02", { "$and": [ { "config.data.type": "femnist" }, { "config.federate.sample_client_rate": 0.2 }, { "state": "finished" }, ] }), # ("cifar10-alpha05", # {"$and": # [ # {"config.data.type": "CIFAR10@torchvision"}, # {"config.data.splitter_args": [{"alpha": 0.5}]}, # ] # } # ), ("sst2", { "$and": [ { "config.data.type": "sst2@huggingface_datasets" }, ] }), ("pubmed", { "$and": [ { "config.data.type": "pubmed" }, ] }), ]) filters_each_line_all_cifar10 = OrderedDict( # {dataset_name: filter} [ ("cifar10-alpha5", { "$and": [ { "config.data.type": "CIFAR10@torchvision" }, { "config.data.splitter_args": [{ "alpha": 5 }] }, ] }), ("cifar10-alpha05", { "$and": [ { "config.data.type": "CIFAR10@torchvision" }, { "config.data.splitter_args": [{ "alpha": 0.5 }] }, ] }), ("cifar10-alpha01", { "$and": [ { "config.data.type": "CIFAR10@torchvision" }, { "config.data.splitter_args": [{ "alpha": 0.1 }] }, ] }), ]) filters_each_line_femnist_all_s = OrderedDict( # {dataset_name: filter} [ ("FEMNIST-s02", { "$and": [ { "config.data.type": "femnist" }, { "config.federate.sample_client_rate": 0.2 }, { "state": "finished" }, ] }), ("FEMNIST-s01", { "$and": [ { "config.data.type": "femnist" }, { "config.federate.sample_client_rate": 0.1 }, { "state": "finished" }, ] }), ("FEMNIST-s005", { "$and": [ { "config.data.type": "femnist" }, { "config.federate.sample_client_rate": 0.05 }, { "state": "finished" }, ] }), ]) filters_each_line_all_graph = OrderedDict( # {dataset_name: filter} [ ("pubmed", { "$and": [ { "config.data.type": "pubmed" }, ] }), ("cora", { "$and": [ { "config.data.type": "cora" }, ] }), ("citeseer", { "$and": [ { "config.data.type": "citeseer" }, ] }), ]) filters_each_line_all_nlp = OrderedDict( # {dataset_name: filter} [ ("cola", { "$and": [ { "config.data.type": "cola@huggingface_datasets" }, ] }), ("sst2", { "$and": [ { "config.data.type": "sst2@huggingface_datasets" }, ] }), ]) sweep_name_2_id = dict() column_names_generalization = [ "best_client_summarized_weighted_avg/test_acc", "best_unseen_client_summarized_weighted_avg_unseen/test_acc", "participation_gap" ] column_names_fair = [ "best_client_summarized_avg/test_acc", "best_client_summarized_fairness/test_acc_std", "best_client_summarized_fairness/test_acc_bottom_decile" ] column_names_efficiency = [ "sys_avg/total_flops", "sys_avg/total_upload_bytes", "sys_avg/total_download_bytes", "sys_avg/global_convergence_round", # "sys_avg/local_convergence_round" ] column_names_generalization_for_plot = [ "Acc (Parti.)", "Acc (Un-parti.)", "Generalization Gap" ] column_name_for_plot = { "best_client_summarized_weighted_avg/test_acc": "Acc (Parti.)", "total_flops": "Total Flops", "communication_bytes": "Communication Bytes", "sys_avg/global_convergence_round": "Convergence Round", } sorted_method_name_pair = [ ("global-train", "Global-Train"), ("isolated-train", "Isolated"), ("fedavg", "FedAvg"), ("fedavg-ft", "FedAvg-FT"), ("fedopt", "FedOpt"), ("fedopt-ft", "FedOpt-FT"), ("pfedme", "pFedMe"), ("ft-pfedme", "pFedMe-FT"), ("fedbn", "FedBN"), ("fedbn-ft", "FedBN-FT"), ("fedbn-fedopt", "FedBN-FedOPT"), ("fedbn-fedopt-ft", "FedBN-FedOPT-FT"), ("ditto", "Ditto"), ("ditto-ft", "Ditto-FT"), ("ditto-fedbn", "Ditto-FedBN"), ("ditto-fedbn-ft", "Ditto-FedBN-FT"), ("ditto-fedbn-fedopt", "Ditto-FedBN-FedOpt"), ("ditto-fedbn-fedopt-ft", "Ditto-FedBN-FedOpt-FT"), ("fedem", "FedEM"), ("fedem-ft", "FedEM-FT"), ("fedbn-fedem", "FedEM-FedBN"), ("fedbn-fedem-ft", "FedEM-FedBN-FT"), ("fedbn-fedem-fedopt", "FedEM-FedBN-FedOPT"), ("fedbn-fedem-fedopt-ft", "FedEM-FedBN-FedOPT-FT"), ] sorted_keys = OrderedDict(sorted_method_name_pair) expected_keys = set(list(sorted_keys.keys())) expected_method_names = list(sorted_keys.values()) expected_datasets_name = [ "cola", "sst2", "pubmed", "cora", "citeseer", "cifar10-alpha5", "cifar10-alpha05", "cifar10-alpha01", "FEMNIST-s02", "FEMNIST-s01", "FEMNIST-s005" ] expected_seed_set = ["1", "2", "3"] expected_expname_tag = set() original_method_names = [ "Global-Train", "Isolated", "FedAvg", "pFedMe", "FedBN", "Ditto", "FedEM" ] for method_name in expected_method_names: for dataset_name in expected_datasets_name: for seed in expected_seed_set: expected_expname_tag.add( f"{method_name}_{dataset_name}_seed{seed}") expected_expname_tag.add(f"{method_name}_{dataset_name}_repeat") from collections import defaultdict all_missing_scripts = defaultdict(list) all_res_structed = defaultdict(dict) for expname_tag in expected_expname_tag: for metric in column_names_generalization + column_names_efficiency + column_names_fair: if "repeat" in expname_tag: all_res_structed[expname_tag][metric] = [] else: all_res_structed[expname_tag][metric] = "-" def load_best_repeat_res(filter_seed_set=None): for expname_tag in expected_expname_tag: filter = { "$and": [ { "config.expname_tag": expname_tag }, ] } filtered_runs = api.runs("pfl-bench-best-repeat", filters=filter) method, dataname, seed = expname_tag.split("_") finished_run_cnt = 0 for run in filtered_runs: if run.state != "finished": print(f"run {run} is not fished") else: finished_run_cnt += 1 for metric in column_names_generalization + column_names_efficiency + column_names_fair: try: if method in ["Isolated", "Global-Train"]: skip_generalize = "unseen" in metric or metric == "participation_gap" skip_global_fairness = method == "Global-Train" and "fairness" in metric if skip_generalize or skip_global_fairness: all_res_structed[expname_tag][metric] = "-" continue if metric == "participation_gap": all_res_structed[expname_tag][metric] = all_res_structed[expname_tag][ "best_unseen_client_summarized_weighted_avg_unseen/test_acc"] - \ all_res_structed[expname_tag][ "best_client_summarized_weighted_avg/test_acc"] else: all_res_structed[expname_tag][ metric] = run.summary[metric] except KeyError: print("Something wrong") print_missing = True for seed in filter_seed_set: if seed in expname_tag: print_missing = False if finished_run_cnt == 0 and print_missing: print(f"Missing run {expname_tag})") yaml_name = f"{method}_{dataname}.yaml" if "Global" in method: yaml_name = f"\'{yaml_name}\'" expname_tag_new = expname_tag.replace("Global Train", "Global-Train") else: expname_tag_new = expname_tag seed_num = seed.replace("seed", "") all_missing_scripts[seed].append( f"python federatedscope/main.py --cfg scripts/personalization_exp_scripts/pfl_bench/yaml_best_runs/{yaml_name} seed {seed_num} expname_tag {expname_tag_new} wandb.name_project pfl-bench-best-repeat" ) elif finished_run_cnt != 1 and print_missing: print(f"run_cnt = {finished_run_cnt} for the exp {expname_tag}") for seed in all_missing_scripts.keys(): print( f"+================= All MISSING SCRIPTS, seed={seed} =====================+, cnt={len(all_missing_scripts[seed])}" ) for scipt in all_missing_scripts[seed]: print(scipt) print() def bytes_to_unit_size(size_bytes): import math if size_bytes == 0: return "0" size_name = ("", "K", "M", "G", "T", "P", "E", "Z", "Y") i = int(math.floor(math.log(size_bytes, 1024))) p = math.pow(1024, i) s = round(size_bytes / p, 2) return f"{s}{size_name[i]}" def unit_size_to_bytes(size_str): if not isinstance(size_str, str): return size_str else: try: last_unit = size_str[-1] size_name = ("", "K", "M", "G", "T", "P", "E", "Z", "Y") if last_unit not in size_name: return float(size_str) else: # need transform import math idx = size_name.index(last_unit) p = math.pow(1024, idx) return float(size_str[:-1]) * p except: return size_str def avg_res_of_seeds(): # add all res to repeat for expname_tag in expected_expname_tag: if "repeat" in expname_tag: continue else: for metric in column_names_generalization + column_names_efficiency + column_names_fair: if all_res_structed[expname_tag][ metric] == "-" and "Global" not in expname_tag and "Isolated" not in expname_tag: print(f"missing {expname_tag} for metric {metric}") method, dataname, seed = expname_tag.split("_") cur_res = all_res_structed[expname_tag][metric] all_res_structed[f"{method}_{dataname}_repeat"][metric].append( cur_res) for expname_tag in expected_expname_tag: if "repeat" in expname_tag: for metric in column_names_generalization + column_names_efficiency + column_names_fair: valid_res = [ unit_size_to_bytes(v) for v in all_res_structed[expname_tag][metric] if v != "-" ] if len(valid_res) == 0: all_res_structed[expname_tag][metric] = "-" else: res = sum(valid_res) / len(valid_res) if "flops" in metric or "bytes" in metric: res = bytes_to_unit_size(res) all_res_structed[expname_tag][metric] = res def highlight_tex_res_in_table(res_to_print_matrix_raw, rank_order, need_scale=False, filter_out=None, convergence_case=False): res_to_print_matrix = [] if filter_out is not None: # filter out the Global-Train and Isolated for line in res_to_print_matrix_raw: if line[0] in filter_out: continue else: res_to_print_matrix.append(line) else: res_to_print_matrix = res_to_print_matrix_raw res_np = np.array(res_to_print_matrix) row_len, col_len = res_np.shape if need_scale: vfun = np.vectorize(unit_size_to_bytes) res_np = vfun(res_np) # select method idx method_heads_all = res_np[:, :1] selected_method_idx = [ i for i in range(row_len) if method_heads_all[i] in original_method_names ] raw_i_to_selected_i = {} for idx_i, selected_i in enumerate(selected_method_idx): raw_i_to_selected_i[selected_i] = idx_i # render column by column for col_i, col in enumerate(res_np[:, 1:].T): # first replace the missing results into numerical res if rank_order[col_i] == "+": # order == "+" indicates the larger, the better col = np.where(col == "-", -999999999, col) else: col = np.where(col == "-", 9999999999, col) if convergence_case: col = np.where(col == "0", 9999999999, col) col = col.astype("float") if rank_order[col_i] == "+": col = -col col_all = pd.DataFrame(col) ind_all_method = col_all.rank( method='dense').astype(int)[0].values.tolist() col_filter = pd.DataFrame(col[selected_method_idx]) ind_partial_method_tmp = col_filter.rank( method='dense').astype(int)[0].values.tolist() for raw_i in range(row_len): if ind_all_method[raw_i] == 1: res_to_print_matrix[raw_i][ col_i + 1] = "\\textbf{" + res_to_print_matrix[raw_i][col_i + 1] + "}" if ind_all_method[raw_i] == 2: res_to_print_matrix[raw_i][ col_i + 1] = "\\underline{" + res_to_print_matrix[raw_i][col_i + 1] + "}" if raw_i in selected_method_idx and ind_partial_method_tmp[ raw_i_to_selected_i[raw_i]] == 1: res_to_print_matrix[raw_i][ col_i + 1] = "\\color{red}{" + res_to_print_matrix[raw_i][col_i + 1] + "}" if raw_i in selected_method_idx and ind_partial_method_tmp[ raw_i_to_selected_i[raw_i]] == 2: res_to_print_matrix[raw_i][ col_i + 1] = "\\color{blue}{" + res_to_print_matrix[raw_i][col_i + 1] + "}" return res_to_print_matrix def print_paper_table_from_repeat(filters_each_line_table): res_of_each_line_generalization = OrderedDict() res_of_each_line_fair = OrderedDict() res_of_each_line_efficiency = OrderedDict() res_of_each_line_commu_acc_trade = OrderedDict() res_of_each_line_conver_acc_trade = OrderedDict() for key in expected_method_names: res_of_each_line_generalization[key] = [] res_of_each_line_fair[key] = [] res_of_each_line_efficiency[key] = [] for dataset_name in filters_each_line_table: expname_tag = f"{key}_{dataset_name}_repeat" for metric in column_names_generalization: res_of_each_line_generalization[key].append( all_res_structed[expname_tag][metric]) for metric in column_names_fair: res_of_each_line_fair[key].append( all_res_structed[expname_tag][metric]) for metric in column_names_efficiency: res = all_res_structed[expname_tag][metric] if "round" in metric: res = "{:.2f}".format(res) res_of_each_line_efficiency[key].append(res) print("\n=============res_of_each_line [Generalization]===============" + ",".join(list(filters_each_line_table.keys()))) # Acc, Unseen-ACC, Delta res_to_print_matrix = [] for key in expected_method_names: res_to_print = [ "{:.2f}".format(v * 100) if v != "-" else v for v in res_of_each_line_generalization[key] ] res_to_print = [key] + res_to_print res_to_print_matrix.append(res_to_print) # print("&".join(res_to_print) + "\\\\") colum_order_per_data = ["+", "+", "+"] # "+" indicates the larger, the better rank_order = colum_order_per_data * len(filters_each_line_table) res_to_print_matrix = highlight_tex_res_in_table(res_to_print_matrix, rank_order=rank_order) for res_to_print in res_to_print_matrix: print("&".join(res_to_print) + "\\\\") print("\n=============res_of_each_line [Fairness]===============" + ",".join(list(filters_each_line_table.keys()))) res_to_print_matrix = [] for key in expected_method_names: res_to_print = [ "{:.2f}".format(v * 100) if v != "-" else v for v in res_of_each_line_fair[key] ] res_to_print = [key] + res_to_print res_to_print_matrix.append(res_to_print) # print("&".join(res_to_print) + "\\\\") colum_order_per_data = ["+", "-", "+"] # "+" indicates the larger, the better rank_order = colum_order_per_data * len(filters_each_line_table) res_to_print_matrix = highlight_tex_res_in_table( res_to_print_matrix, rank_order=rank_order, filter_out=["Global-Train"]) for res_to_print in res_to_print_matrix: print("&".join(res_to_print) + "\\\\") # print("\n=============res_of_each_line [All Efficiency]===============" + ",".join( # list(filters_each_line_table.keys()))) ## FLOPS, UPLOAD, DOWNLOAD # for key in expected_method_names: # res_to_print = [str(v) for v in res_of_each_line_efficiency[key]] # res_to_print = [key] + res_to_print # print("&".join(res_to_print) + "\\\\") print( "\n=============res_of_each_line [flops, communication, acc]===============" + ",".join(list(filters_each_line_table.keys()))) res_to_print_matrix = [] for key in expected_method_names: res_of_each_line_commu_acc_trade[key] = [] dataset_num = 2 if "cola" in list( filters_each_line_table.keys()) else 3 for i in range(dataset_num): res_of_each_line_commu_acc_trade[key].extend( [str(res_of_each_line_efficiency[key][i * 4])] + \ [str(res_of_each_line_efficiency[key][i * 4 + 1])] + \ ["{:.2f}".format(v * 100) if v != "-" else v for v in res_of_each_line_generalization[key][i * 3:i * 3 + 1]] ) res_to_print = [str(v) for v in res_of_each_line_commu_acc_trade[key]] res_to_print = [key] + res_to_print res_to_print_matrix.append(res_to_print) # print("&".join(res_to_print)+ "\\\\") colum_order_per_data = ["-", "-", "+"] # "+" indicates the larger, the better rank_order = colum_order_per_data * len(filters_each_line_table) res_to_print_matrix = highlight_tex_res_in_table( res_to_print_matrix, rank_order=rank_order, need_scale=True, filter_out=["Global-Train", "Isolated"]) for res_to_print in res_to_print_matrix: print("&".join(res_to_print) + "\\\\") print( "\n=============res_of_each_line [converge_round, acc]===============" + ",".join(list(filters_each_line_table.keys()))) res_to_print_matrix = [] for key in expected_method_names: res_of_each_line_conver_acc_trade[key] = [] dataset_num = 2 if "cola" in list( filters_each_line_table.keys()) else 3 for i in range(dataset_num): res_of_each_line_conver_acc_trade[key].extend( [str(res_of_each_line_efficiency[key][i * 4 + 3])] + \ # [str(res_of_each_line_efficiency[key][i * 4 + 4])] + \ ["{:.2f}".format(v * 100) if v != "-" else v for v in res_of_each_line_fair[key][i * 3:i * 3 + 1]] ) res_to_print = [str(v) for v in res_of_each_line_conver_acc_trade[key]] res_to_print = [key] + res_to_print res_to_print_matrix.append(res_to_print) # print("&".join(res_to_print) + "\\\\") colum_order_per_data = ["-", "+"] # "+" indicates the larger, the better rank_order = colum_order_per_data * len(filters_each_line_table) res_to_print_matrix = highlight_tex_res_in_table( res_to_print_matrix, rank_order=rank_order, filter_out=["Global-Train", "Isolated"], convergence_case=True) for res_to_print in res_to_print_matrix: print("&".join(res_to_print) + "\\\\") import json with open('best_res_all_metric.json', 'r') as fp: all_res_structed_load = json.load(fp) for expname_tag in expected_expname_tag: if "repeat" in expname_tag: continue for metric in column_names_generalization + column_names_efficiency + column_names_fair: all_res_structed[expname_tag][metric] = all_res_structed_load[ expname_tag][metric] # add all res to a df import pandas as pd def load_data_to_pd(use_repeat_res=False): all_res_for_pd = [] for expname_tag in expected_expname_tag: if not use_repeat_res: if "repeat" in expname_tag: continue else: if not "repeat" in expname_tag: continue res = expname_tag.split("_") # method, data, seed for metric in column_names_generalization + column_names_fair + column_names_efficiency: res.append(all_res_structed[expname_tag][metric]) s = "-" alpha = "-" if "FEMNIST-s0" in res[1]: s = float(res[1].replace("FEMNIST-s0", "0.")) if "cifar10-alpha0" in res[1]: alpha = float(res[1].replace("cifar10-alpha0", "0.")) elif "cifar10-alpha" in res[1]: alpha = float(res[1].replace("cifar10-alpha", "")) res.append(s) res.append(alpha) total_com_bytes = unit_size_to_bytes(res[-5]) + unit_size_to_bytes( res[-4]) total_flops = unit_size_to_bytes(res[-6]) res.append(total_com_bytes) res.append(total_flops) all_res_for_pd.append(res) all_res_pd = pd.DataFrame().from_records( all_res_for_pd, columns=["method", "data", "seed"] + column_names_generalization + column_names_fair + column_names_efficiency + ["s", "alpha", "communication_bytes", "total_flops"]) return all_res_pd def plot_generalization_lines(all_res_pd, data_cate, data_cate_name): import seaborn as sns from matplotlib import pyplot as plt import matplotlib.pylab as pylab plt.clf() sns.set() fig, axes = plt.subplots(1, 3, figsize=(6, 4)) print(all_res_pd.columns.tolist()) plot_data = all_res_pd.loc[all_res_pd["data"].isin(data_cate)] plot_data = plot_data.loc[plot_data["method"] != "Global-Train"] plot_data = plot_data.loc[plot_data["method"] != "Isolated"] plot_data = plot_data.loc[plot_data["method"] != "FedOpt"] plot_data = plot_data.loc[plot_data["method"] != "FedOpt-FT"] filter_out_methods = ["Global-Train", "Isolated", "FedOpt", "FedOpt-FT"] for i, metric in enumerate(column_names_generalization): plt.clf() sns.set() fig, axes = plt.subplots(1, 1, figsize=(2, 3)) x = "data" if data_cate_name == "femnist_all": x = "s" if data_cate_name == "cifar10_all": x = "alpha" ax = sns.lineplot( ax=axes, data=plot_data, x=x, y=metric, hue="method", style="method", markers=True, dashes=True, hue_order=[ m for m in expected_method_names if m not in filter_out_methods ], sort=True, ) ax.set(ylabel=column_names_generalization_for_plot[i]) plt.gca().invert_xaxis() if data_cate_name == "cifar10_all": ax.set_xscale('log') plt.legend(bbox_to_anchor=(1, 1), loc=2, ncol=2, borderaxespad=0.) plt.tight_layout() plt.savefig(f"generalization_all_{data_cate_name}_{i}.pdf", bbox_inches='tight', pad_inches=0) plt.show() def plot_tradeoff(all_res_pd, data_cate, data_cate_name, metric_a, metric_b, fig_time): import seaborn as sns from matplotlib import pyplot as plt import matplotlib.pylab as pylab plt.clf() sns.set() print(all_res_pd.columns.tolist()) plot_data = all_res_pd.loc[all_res_pd["data"].isin(data_cate)] plot_data = plot_data.loc[plot_data["method"] != "Global-Train"] plot_data = plot_data.loc[plot_data["method"] != "Isolated"] plot_data = plot_data.loc[plot_data["method"] != "FedOpt"] plot_data = plot_data.loc[plot_data["method"] != "FedOpt-FT"] filter_out_methods = ["Global-Train", "Isolated", "FedOpt", "FedOpt-FT"] plt.clf() sns.set() fig, axes = plt.subplots(1, 1, figsize=(2, 3)) ax = sns.scatterplot(ax=axes, data=plot_data, x=metric_a, y=metric_b, hue="method", style="method", markers=True, hue_order=[ m for m in expected_method_names if m not in filter_out_methods ], s=100) ax.set(xlabel=column_name_for_plot[metric_a], ylabel=column_name_for_plot[metric_b]) # plt.gca().invert_xaxis() if metric_a == "total_flops": ax.set_xscale('log') if data_cate_name == "cifar10_all": ax.set_xscale('log') plt.legend(bbox_to_anchor=(1, 1), loc=2, ncol=2, borderaxespad=0.) plt.tight_layout() plt.savefig(f"{fig_time}_{data_cate_name}.pdf", bbox_inches='tight', pad_inches=0) plt.show() if __name__ == "__main__": load_best_repeat_res(["1", "repeat"]) avg_res_of_seeds() print_paper_table_from_repeat(filters_each_line_main_table) print_paper_table_from_repeat(filters_each_line_femnist_all_s) print_paper_table_from_repeat(filters_each_line_all_cifar10) print_paper_table_from_repeat(filters_each_line_all_nlp) print_paper_table_from_repeat(filters_each_line_all_graph) all_res_pd = load_data_to_pd(use_repeat_res=False) all_res_pd_repeat = load_data_to_pd(use_repeat_res=True) def plot_line_figs(): plot_generalization_lines(all_res_pd, list(filters_each_line_femnist_all_s.keys()), data_cate_name="femnist_all") plot_generalization_lines(all_res_pd, list(filters_each_line_all_cifar10.keys()), data_cate_name="cifar10_all") def plot_trade_off_figs(filters_each_line_main_table): for data_name in list(filters_each_line_main_table.keys()): plot_tradeoff(all_res_pd_repeat, [data_name], data_cate_name=data_name, metric_a="communication_bytes", metric_b="best_client_summarized_weighted_avg/test_acc", fig_time="com-acc") for data_name in list(filters_each_line_main_table.keys()): plot_tradeoff(all_res_pd_repeat, [data_name], data_cate_name=data_name, metric_a="total_flops", metric_b="best_client_summarized_weighted_avg/test_acc", fig_time="flops-acc") for data_name in list(filters_each_line_main_table.keys()): plot_tradeoff(all_res_pd_repeat, [data_name], data_cate_name=data_name, metric_a="sys_avg/global_convergence_round", metric_b="best_client_summarized_weighted_avg/test_acc", fig_time="round-acc")