# FS-TFP/federatedscope/core/auxiliaries/logging.py

import copy
import json
import logging
import os
import re
import time
import yaml
import numpy as np
from datetime import datetime
logger = logging.getLogger(__name__)


class CustomFormatter(logging.Formatter):
    """Logging colored formatter, adapted from
    https://stackoverflow.com/a/56944256/3638629"""
    def __init__(self, fmt):
        super().__init__()
        grey = '\x1b[38;21m'
        blue = '\x1b[38;5;39m'
        yellow = "\x1b[33;20m"
        red = '\x1b[38;5;196m'
        bold_red = '\x1b[31;1m'
        reset = '\x1b[0m'
        self.FORMATS = {
            logging.DEBUG: grey + fmt + reset,
            logging.INFO: blue + fmt + reset,
            logging.WARNING: yellow + fmt + reset,
            logging.ERROR: red + fmt + reset,
            logging.CRITICAL: bold_red + fmt + reset
        }

    def format(self, record):
        log_fmt = self.FORMATS.get(record.levelno)
        formatter = logging.Formatter(log_fmt)
        return formatter.format(record)
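
# A minimal usage sketch for CustomFormatter (illustration only; the
# logger/handler names are hypothetical):
#
#   demo_logger = logging.getLogger("color_demo")
#   demo_handler = logging.StreamHandler()
#   demo_handler.setFormatter(
#       CustomFormatter("%(levelname)s: %(message)s"))
#   demo_logger.addHandler(demo_handler)
#   demo_logger.warning("rendered in yellow")  # WARNING -> yellow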


class LoggerPrecisionFilter(logging.Filter):
    def __init__(self, precision):
        super().__init__()
        self.print_precision = precision

    def str_round(self, match_res):
        # the regex below matches only plain float literals, so float()
        # is a safe, equivalent replacement for eval() on the match
        return str(round(float(match_res.group()), self.print_precision))

    def filter(self, record):
        # use regex to find float numbers and round them to the
        # specified precision
        if not isinstance(record.msg, str):
            record.msg = str(record.msg)
        if record.msg != "":
            if re.search(r"([-+]?\d+\.\d+)", record.msg):
                record.msg = re.sub(r"([-+]?\d+\.\d+)", self.str_round,
                                    record.msg)
        return True
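
# A minimal sketch of the rounding behavior (illustration only; the
# filter is attached to a handler here, mirroring update_logger below):
#
#   demo = logging.getLogger("precision_demo")
#   h = logging.StreamHandler()
#   h.addFilter(LoggerPrecisionFilter(precision=3))
#   demo.addHandler(h)
#   demo.warning("loss=0.123456789")  # emitted as "loss=0.123"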


def update_logger(cfg, clear_before_add=False):
    root_logger = logging.getLogger("federatedscope")

    # clear all existing handlers and add the default stream handler
    if clear_before_add:
        root_logger.handlers = []
        handler = logging.StreamHandler()
        fmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
        handler.setFormatter(CustomFormatter(fmt))
        root_logger.addHandler(handler)

    # update level
    if cfg.verbose > 0:
        logging_level = logging.INFO
    else:
        logging_level = logging.WARN
        root_logger.warning("Skip DEBUG/INFO messages")
    root_logger.setLevel(logging_level)

    # ================ create outdir to save log, exp_config, models, etc.
    if cfg.outdir == "":
        cfg.outdir = os.path.join(os.getcwd(), "exp")
    if cfg.expname == "":
        cfg.expname = f"{cfg.federate.method}_{cfg.model.type}_on" \
                      f"_{cfg.data.type}_lr{cfg.train.optimizer.lr}_lste" \
                      f"p{cfg.train.local_update_steps}"
    if cfg.expname_tag:
        cfg.expname = f"{cfg.expname}_{cfg.expname_tag}"
    cfg.outdir = os.path.join(cfg.outdir, cfg.expname)

    # if the directory already exists, switch to a time-stamped
    # sub-directory instead
    if os.path.isdir(cfg.outdir) and os.path.exists(cfg.outdir):
        outdir = os.path.join(cfg.outdir, "sub_exp" +
                              datetime.now().strftime('_%Y%m%d%H%M%S')
                              )  # e.g., sub_exp_20220411030524
        while os.path.exists(outdir):
            time.sleep(1)
            outdir = os.path.join(
                cfg.outdir,
                "sub_exp" + datetime.now().strftime('_%Y%m%d%H%M%S'))
        cfg.outdir = outdir
    # create the (possibly updated) output directory
    os.makedirs(cfg.outdir)

    # create file handler which logs even debug messages
    fh = logging.FileHandler(os.path.join(cfg.outdir, 'exp_print.log'))
    fh.setLevel(logging.DEBUG)
    logger_formatter = logging.Formatter(
        "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    fh.setFormatter(logger_formatter)
    root_logger.addHandler(fh)

    # set print precision for terse logging
    np.set_printoptions(precision=cfg.print_decimal_digits)
    precision_filter = LoggerPrecisionFilter(cfg.print_decimal_digits)
    # attach the filter to each handler rather than to the logger, since
    # "Filters, unlike levels and handlers, do not propagate", ref
    # https://stackoverflow.com/questions/6850798/why-doesnt-filter-
    # attached-to-the-root-logger-propagate-to-descendant-loggers
    for handler in root_logger.handlers:
        handler.addFilter(precision_filter)

    import socket
    root_logger.info(f"the current machine is at"
                     f" {socket.gethostbyname(socket.gethostname())}")
    root_logger.info(f"the current dir is {os.getcwd()}")
    root_logger.info(f"the output dir is {cfg.outdir}")

    if cfg.wandb.use:
        import sys
        # redirect stderr to stdout so both are sent to the wandb server
        sys.stderr = sys.stdout
        init_wandb(cfg)
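
# A minimal usage sketch for update_logger (illustration only, assuming
# a FederatedScope config object; the import path follows the upstream
# FederatedScope layout):
#
#   from federatedscope.core.configs.config import global_cfg
#   cfg = global_cfg.clone()
#   update_logger(cfg, clear_before_add=True)
#   # afterwards, "federatedscope"-namespaced loggers print colored
#   # output to the terminal and plain output to
#   # <cfg.outdir>/exp_print.log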


def init_wandb(cfg):
    try:
        import wandb
        # on some linux machines, we may need "thread" init to avoid
        # memory leakage
        os.environ["WANDB_START_METHOD"] = "thread"
    except ImportError:
        logger.error("cfg.wandb.use=True but the wandb package is not "
                     "installed")
        exit()

    dataset_name = cfg.data.type
    method_name = cfg.federate.method
    exp_name = cfg.expname

    tmp_cfg = copy.deepcopy(cfg)
    if tmp_cfg.is_frozen():
        tmp_cfg.defrost()
    # in most cases, there is no need to save the cfg_check_funcs via wandb
    tmp_cfg.clear_aux_info()
    tmp_cfg.de_arguments()
    cfg_yaml = yaml.safe_load(tmp_cfg.dump())

    wandb.init(project=cfg.wandb.name_project,
               entity=cfg.wandb.name_user,
               config=cfg_yaml,
               group=dataset_name,
               job_type=method_name,
               name=exp_name,
               notes=f"{method_name}, {exp_name}")
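
# A minimal config sketch for init_wandb (illustration only; the values
# are hypothetical, the key names follow the cfg.wandb.* fields used
# above):
#
#   cfg.wandb.use = True
#   cfg.wandb.name_user = "my-entity"
#   cfg.wandb.name_project = "fs-tfp-exps"
#   init_wandb(cfg)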


def logfile_2_wandb_dict(exp_log_f, raw_out=True):
    """
    Parse the log files [exp_print.log, eval_results.log] into a
    wandb_dict that contains non-nested dicts

    :param exp_log_f: opened exp_log file
    :param raw_out: True indicates "exp_print.log", otherwise indicates
        "eval_results.log"; the difference is whether each line contains
        the logger header such as
        "2022-05-02 16:55:02,843 (client:197) INFO:"
    :return: tuple including (all_log_res, exp_stop_normal, last_line,
        log_res_best)
    """
    log_res_best = {}
    exp_stop_normal = False
    all_log_res = []
    last_line = None
    for line in exp_log_f:
        last_line = line
        exp_stop_normal, log_res = logline_2_wandb_dict(
            exp_stop_normal, line, log_res_best, raw_out)
        if "'Role': 'Server #'" in line:
            all_log_res.append(log_res)
    return all_log_res, exp_stop_normal, last_line, log_res_best
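
# A minimal usage sketch for logfile_2_wandb_dict (illustration only;
# the path is hypothetical):
#
#   with open("exp/some_run/exp_print.log") as f:
#       all_res, stop_ok, last, best = logfile_2_wandb_dict(f)
#   # `best` maps flattened metric names to their best values, e.g.,
#   # {"client_best_individual/test_acc": 0.59, ...}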


def logline_2_wandb_dict(exp_stop_normal, line, log_res_best, raw_out):
    log_res = {}
    if "INFO:" in line and "Find new best result for" in line:
        # Logger type 1: one line per metric, e.g.,
        # 2022-03-22 10:48:42,562 (server:459) INFO: Find new best result
        # for client_best_individual.test_acc with value 0.5911787974683544
        line = line.split("INFO: ")[1]
        parse_res = line.split("with value")
        best_key, best_val = parse_res[-2], parse_res[-1]
        # client_best_individual.test_acc -> client_best_individual/test_acc
        best_key = best_key.replace("Find new best result for",
                                    "").replace(".", "/")
        log_res_best[best_key.strip()] = float(best_val.strip())

    if "Find new best result:" in line:
        # Logger type 2: one line for all metrics of a role, e.g.,
        # Find new best result: {'Client #1': {'val_loss':
        # 132.9812364578247, 'test_total': 36, 'test_avg_loss':
        # 3.709533585442437, 'test_correct': 2.0, 'test_loss':
        # 133.54320907592773, 'test_acc': 0.05555555555555555, 'val_total':
        # 36, 'val_avg_loss': 3.693923234939575, 'val_correct': 4.0,
        # 'val_acc': 0.1111111111111111}}
        line = line.replace("Find new best result: ", "").replace("\'", "\"")
        res = json.loads(s=line)
        for best_type_key, val in res.items():
            for inner_key, inner_val in val.items():
                log_res_best[f"best_{best_type_key}/{inner_key}"] = inner_val

    if "'Role'" in line:
        if raw_out:
            line = line.split("INFO: ")[1]
        res = line.replace("\'", "\"")
        res = json.loads(s=res)
        # pre-process the roles
        cur_round = res['Round']
        if "Server" in res['Role']:
            if cur_round != "Final" and 'Results_raw' in res:
                res.pop('Results_raw')
        role = res.pop('Role')
        # parse the k-v pairs
        for key, val in res.items():
            if not isinstance(val, dict):
                log_res[f"{role}, {key}"] = val
            else:
                if cur_round != "Final":
                    if key == "Results_raw":
                        for key_inner, val_inner in res[
                                "Results_raw"].items():
                            log_res[f"{role}, {key_inner}"] = val_inner
                    else:
                        for key_inner, val_inner in val.items():
                            assert not isinstance(val_inner, dict), \
                                "Un-expected log format"
                            log_res[f"{role}, {key}/{key_inner}"] = val_inner
                else:
                    exp_stop_normal = True
                    if key == "Results_raw":
                        for final_type, final_type_dict in res[
                                "Results_raw"].items():
                            for inner_key, inner_val in \
                                    final_type_dict.items():
                                log_res_best[
                                    f"{final_type}/{inner_key}"] = inner_val
    return exp_stop_normal, log_res
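
# A minimal usage sketch for logline_2_wandb_dict (illustration only;
# the sample line mirrors the "Logger type 1" format documented above):
#
#   best = {}
#   sample = ("2022-03-22 10:48:42,562 (server:459) INFO: Find new best "
#             "result for client_best_individual.test_acc "
#             "with value 0.5911787974683544")
#   stop_ok, _ = logline_2_wandb_dict(False, sample, best, raw_out=True)
#   # best == {"client_best_individual/test_acc": 0.5911787974683544}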