FS-TFP/federatedscope/nlp/hetero_tasks/trainer/utils.py

81 lines
2.2 KiB
Python

class AverageMeter(object):
def __init__(self):
self.reset()
def reset(self):
self.avg = 0
self.sum = 0
self.cnt = 0
self.val = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.cnt += n
self.avg = self.sum / self.cnt
class ContrastiveMonitor(object):
def __init__(self,
stat=1,
enc_hidden=None,
synth_tokens=None,
dec_hidden=None,
dec_out=None,
all_group_ids=None,
topk_group_ids=None):
self.stat = stat
self.enc_hidden = enc_hidden
self.synth_tokens = synth_tokens
self.dec_hidden = dec_hidden
self.dec_out = dec_out
self.all_group_ids = all_group_ids
self.topk_group_ids = topk_group_ids
def update_stat(self, status):
self.stat = status
def update_all_group_ids(self, group_ids):
self.all_group_ids = group_ids
def update_topk_group_ids(self, group_ids):
self.topk_group_ids = group_ids
def update_enc_hidden(self, enc_hidden, k=None):
if k is None:
self.enc_hidden = enc_hidden
else:
if self.enc_hidden is None:
self.enc_hidden = {}
self.enc_hidden[k] = enc_hidden
def update_synth_tokens(self, synth_tokens, k=None):
if k is None:
self.synth_tokens = synth_tokens
else:
if self.synth_tokens is None:
self.synth_tokens = {}
self.synth_tokens[k] = synth_tokens
def update_dec_hidden(self, dec_hidden, k=None):
if k is None:
self.dec_hidden = dec_hidden
else:
if self.dec_hidden is None:
self.dec_hidden = {}
self.dec_hidden[k] = dec_hidden
def update_dec_out(self, dec_out, k=None):
if k is None:
self.dec_out = dec_out
else:
if self.dec_out is None:
self.dec_out = {}
self.dec_out[k] = dec_out
def reset(self):
self.stat = 1
self.dec_hidden = None
self.dec_out = None
self.group_ids = None