"""Registers the ``poison_attack_acc`` evaluation metric."""
import numpy as np

from federatedscope.register import register_metric


def compute_poison_metric(ctx):
    """Return the fraction of poisoned samples whose argmax prediction
    equals the stored poison label.

    Reads ``ctx['poison_<split>_y_true']`` and ``ctx['poison_<split>_y_prob']``
    where ``<split>`` is ``ctx.cur_split``.

    Args:
        ctx: context object exposing ``cur_split`` as an attribute and the
            two ``poison_*`` entries through item access.
            # assumes y_prob is 2-D (samples x classes), since argmax is
            # taken over axis 1 -- TODO confirm against the trainer.

    Returns:
        float: number of matching predictions divided by ``len`` of the
        comparison result.

    Raises:
        ZeroDivisionError: if the comparison result has length zero
            (i.e. there are no poisoned samples for this split).
    """
    key_prefix = 'poison_' + ctx.cur_split
    labels = ctx[key_prefix + '_y_true']
    probs = ctx[key_prefix + '_y_prob']

    # Predicted class = index of the highest score along axis 1.
    predicted = np.argmax(probs, axis=1)
    matches = labels == predicted

    num_hits = float(np.sum(matches))
    return num_hits / len(matches)


def load_poison_metrics(ctx, y_true, y_pred, y_prob, **kwargs):
    """Metric callback for ``poison_attack_acc``.

    The ``y_true``/``y_pred``/``y_prob`` arguments are accepted for
    signature compatibility but are not used; the poison-specific arrays
    are read from ``ctx`` instead.

    Returns:
        ``None`` when ``ctx.cur_split`` is ``'train'``; otherwise the value
        of :func:`compute_poison_metric` for the current split.
    """
    # The metric is skipped on the training split.
    if ctx.cur_split == 'train':
        return None
    return compute_poison_metric(ctx)


def call_poison_metric(types):
    """Metric builder handed to ``register_metric``.

    Args:
        types: collection of requested metric names (only membership of
            ``'poison_attack_acc'`` is checked).

    Returns:
        A ``(name, callback, the_larger_the_better)`` tuple when
        ``'poison_attack_acc'`` is requested, otherwise ``None``.
    """
    if 'poison_attack_acc' not in types:
        return None

    # A higher attack accuracy is reported as "better" for this metric.
    return 'poison_attack_acc', load_poison_metrics, True


register_metric('poison_attack_acc', call_poison_metric)