import copy
import json
import logging
import os
import re
import time

import yaml
import numpy as np
from datetime import datetime

logger = logging.getLogger(__name__)


class CustomFormatter(logging.Formatter):
    """Logging colored formatter, adapted from
    https://stackoverflow.com/a/56944256/3638629
    """
    def __init__(self, fmt):
        super().__init__()
        grey = '\x1b[38;21m'
        blue = '\x1b[38;5;39m'
        yellow = "\x1b[33;20m"
        red = '\x1b[38;5;196m'
        bold_red = '\x1b[31;1m'
        reset = '\x1b[0m'

        self.FORMATS = {
            logging.DEBUG: grey + fmt + reset,
            logging.INFO: blue + fmt + reset,
            logging.WARNING: yellow + fmt + reset,
            logging.ERROR: red + fmt + reset,
            logging.CRITICAL: bold_red + fmt + reset
        }

    def format(self, record):
        log_fmt = self.FORMATS.get(record.levelno)
        formatter = logging.Formatter(log_fmt)
        return formatter.format(record)


class LoggerPrecisionFilter(logging.Filter):
    def __init__(self, precision):
        super().__init__()
        self.print_precision = precision

    def str_round(self, match_res):
        # cast via float() rather than eval(), since the regex only ever
        # matches plain float literals
        return str(round(float(match_res.group()), self.print_precision))

    def filter(self, record):
        # use regex to find float numbers and round them to the specified
        # precision
        if not isinstance(record.msg, str):
            record.msg = str(record.msg)
        if record.msg != "":
            if re.search(r"([-+]?\d+\.\d+)", record.msg):
                record.msg = re.sub(r"([-+]?\d+\.\d+)", self.str_round,
                                    record.msg)
        return True


def update_logger(cfg, clear_before_add=False):
    root_logger = logging.getLogger("federatedscope")

    # clear all existing handlers and add the default stream handler
    if clear_before_add:
        root_logger.handlers = []
        handler = logging.StreamHandler()
        fmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
        handler.setFormatter(CustomFormatter(fmt))
        root_logger.addHandler(handler)

    # update level
    if cfg.verbose > 0:
        logging_level = logging.INFO
    else:
        logging_level = logging.WARN
        root_logger.warning("Skip DEBUG/INFO messages")
    root_logger.setLevel(logging_level)

    # ================ create outdir to save log, exp_config, models, etc.
if cfg.outdir == "": cfg.outdir = os.path.join(os.getcwd(), "exp") if cfg.expname == "": cfg.expname = f"{cfg.federate.method}_{cfg.model.type}_on" \ f"_{cfg.data.type}_lr{cfg.train.optimizer.lr}_lste" \ f"p{cfg.train.local_update_steps}" if cfg.expname_tag: cfg.expname = f"{cfg.expname}_{cfg.expname_tag}" cfg.outdir = os.path.join(cfg.outdir, cfg.expname) # if exist, make directory with given name and time if os.path.isdir(cfg.outdir) and os.path.exists(cfg.outdir): outdir = os.path.join(cfg.outdir, "sub_exp" + datetime.now().strftime('_%Y%m%d%H%M%S') ) # e.g., sub_exp_20220411030524 while os.path.exists(outdir): time.sleep(1) outdir = os.path.join( cfg.outdir, "sub_exp" + datetime.now().strftime('_%Y%m%d%H%M%S')) cfg.outdir = outdir # if not, make directory with given name os.makedirs(cfg.outdir) # create file handler which logs even debug messages fh = logging.FileHandler(os.path.join(cfg.outdir, 'exp_print.log')) fh.setLevel(logging.DEBUG) logger_formatter = logging.Formatter( "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s") fh.setFormatter(logger_formatter) root_logger.addHandler(fh) # set print precision for terse logging np.set_printoptions(precision=cfg.print_decimal_digits) precision_filter = LoggerPrecisionFilter(cfg.print_decimal_digits) # attach the filter to the fh handler to propagate the filter, since # "Filters, unlike levels and handlers, do not propagate", # ref https://stackoverflow.com/questions/6850798/why-doesnt-filter- # attached-to-the-root-logger-propagate-to-descendant-loggers for handler in root_logger.handlers: handler.addFilter(precision_filter) import socket root_logger.info(f"the current machine is at" f" {socket.gethostbyname(socket.gethostname())}") root_logger.info(f"the current dir is {os.getcwd()}") root_logger.info(f"the output dir is {cfg.outdir}") if cfg.wandb.use: import sys sys.stderr = sys.stdout # make both stderr and stdout sent to wandb # server init_wandb(cfg) def init_wandb(cfg): try: import wandb # on some linux machines, we may need "thread" init to avoid memory # leakage os.environ["WANDB_START_METHOD"] = "thread" except ImportError: logger.error("cfg.wandb.use=True but not install the wandb package") exit() dataset_name = cfg.data.type method_name = cfg.federate.method exp_name = cfg.expname tmp_cfg = copy.deepcopy(cfg) if tmp_cfg.is_frozen(): tmp_cfg.defrost() tmp_cfg.clear_aux_info( ) # in most cases, no need to save the cfg_check_funcs via wandb tmp_cfg.de_arguments() cfg_yaml = yaml.safe_load(tmp_cfg.dump()) wandb.init(project=cfg.wandb.name_project, entity=cfg.wandb.name_user, config=cfg_yaml, group=dataset_name, job_type=method_name, name=exp_name, notes=f"{method_name}, {exp_name}") def logfile_2_wandb_dict(exp_log_f, raw_out=True): """ parse the logfiles [exp_print.log, eval_results.log] into wandb_dict that contains non-nested dicts :param exp_log_f: opened exp_log file :param raw_out: True indicates "exp_print.log", otherwise indicates "eval_results.log", the difference is whether contains the logger header such as "2022-05-02 16:55:02,843 (client:197) INFO:" :return: tuple including (all_log_res, exp_stop_normal, last_line, log_res_best) """ log_res_best = {} exp_stop_normal = False all_log_res = [] last_line = None for line in exp_log_f: last_line = line exp_stop_normal, log_res = logline_2_wandb_dict( exp_stop_normal, line, log_res_best, raw_out) if "'Role': 'Server #'" in line: all_log_res.append(log_res) return all_log_res, exp_stop_normal, last_line, log_res_best def logline_2_wandb_dict(exp_stop_normal, 
                         line, log_res_best, raw_out):
    log_res = {}
    if "INFO:" in line and "Find new best result for" in line:
        # Logger type 1: one line per metric, e.g.,
        # 2022-03-22 10:48:42,562 (server:459) INFO: Find new best result
        # for client_best_individual.test_acc with value 0.5911787974683544
        line = line.split("INFO: ")[1]
        parse_res = line.split("with value")
        best_key, best_val = parse_res[-2], parse_res[-1]
        # client_best_individual.test_acc -> client_best_individual/test_acc
        best_key = best_key.replace("Find new best result for",
                                    "").replace(".", "/")
        log_res_best[best_key.strip()] = float(best_val.strip())

    if "Find new best result:" in line:
        # Logger type 2: one line for all metrics of a role, e.g.,
        # Find new best result: {'Client #1': {'val_loss':
        # 132.9812364578247, 'test_total': 36, 'test_avg_loss':
        # 3.709533585442437, 'test_correct': 2.0, 'test_loss':
        # 133.54320907592773, 'test_acc': 0.05555555555555555, 'val_total':
        # 36, 'val_avg_loss': 3.693923234939575, 'val_correct': 4.0,
        # 'val_acc': 0.1111111111111111}}
        line = line.replace("Find new best result: ", "").replace("\'", "\"")
        res = json.loads(line)
        for best_type_key, val in res.items():
            for inner_key, inner_val in val.items():
                log_res_best[f"best_{best_type_key}/{inner_key}"] = inner_val

    if "'Role'" in line:
        if raw_out:
            line = line.split("INFO: ")[1]
        res = line.replace("\'", "\"")
        res = json.loads(res)
        # pre-process the roles
        cur_round = res['Round']
        if "Server" in res['Role']:
            if cur_round != "Final" and 'Results_raw' in res:
                res.pop('Results_raw')
        role = res.pop('Role')
        # parse the k-v pairs
        for key, val in res.items():
            if not isinstance(val, dict):
                log_res[f"{role}, {key}"] = val
            else:
                if cur_round != "Final":
                    if key == "Results_raw":
                        for key_inner, val_inner in \
                                res["Results_raw"].items():
                            log_res[f"{role}, {key_inner}"] = val_inner
                    else:
                        for key_inner, val_inner in val.items():
                            assert not isinstance(val_inner, dict), \
                                "Unexpected log format"
                            log_res[f"{role}, {key}/{key_inner}"] = val_inner
                else:
                    exp_stop_normal = True
                    if key == "Results_raw":
                        for final_type, final_type_dict in \
                                res["Results_raw"].items():
                            for inner_key, inner_val in \
                                    final_type_dict.items():
                                log_res_best[
                                    f"{final_type}/{inner_key}"] = inner_val
    return exp_stop_normal, log_res
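

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API):
# feeding a single "Find new best result" log line through
# logline_2_wandb_dict. The sample line below is made up and merely mimics
# the "Logger type 1" format documented in the function above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sample_line = ("2022-03-22 10:48:42,562 (server:459) INFO: Find new "
                   "best result for client_best_individual.test_acc "
                   "with value 0.5912")
    log_res_best = {}
    exp_stop_normal, log_res = logline_2_wandb_dict(
        exp_stop_normal=False,
        line=sample_line,
        log_res_best=log_res_best,
        raw_out=True)
    # expected output: {'client_best_individual/test_acc': 0.5912}
    print(log_res_best)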