81 lines
2.2 KiB
Python
81 lines
2.2 KiB
Python
class AverageMeter(object):
    """Track the most recent value and a running weighted average.

    Attributes:
        val: the value passed to the latest ``update`` call.
        sum: running sum of ``val * n`` over all updates.
        cnt: running sum of the weights ``n``.
        avg: ``sum / cnt``; stays at its previous value while ``cnt`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.avg = 0
        self.sum = 0
        self.cnt = 0
        self.val = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: value to record.
            n: weight of ``val`` (number of samples it represents); default 1.
        """
        self.val = val
        self.sum += val * n
        self.cnt += n
        # A zero total weight (e.g. update(v, n=0) before any weighted sample)
        # would otherwise raise ZeroDivisionError; keep the previous average.
        if self.cnt:
            self.avg = self.sum / self.cnt
|
|
|
|
|
|
class ContrastiveMonitor(object):
    """Mutable holder for intermediate values captured between calls.

    The slots ``enc_hidden``, ``synth_tokens``, ``dec_hidden`` and ``dec_out``
    each hold either a single value (when updated without ``k``) or a dict
    mapping ``k`` to values (when updated with ``k``). ``stat`` is a status
    flag that ``reset`` restores to 1; its meaning is defined by callers --
    NOTE(review): confirm semantics. ``all_group_ids``/``topk_group_ids`` are
    stored as given.
    """

    def __init__(self,
                 stat=1,
                 enc_hidden=None,
                 synth_tokens=None,
                 dec_hidden=None,
                 dec_out=None,
                 all_group_ids=None,
                 topk_group_ids=None):
        self.stat = stat
        self.enc_hidden = enc_hidden
        self.synth_tokens = synth_tokens
        self.dec_hidden = dec_hidden
        self.dec_out = dec_out
        self.all_group_ids = all_group_ids
        self.topk_group_ids = topk_group_ids

    def _update_slot(self, name, value, k):
        """Store ``value`` in attribute ``name``, wholesale or under key ``k``.

        With ``k is None`` the attribute is replaced. Otherwise the attribute
        is used as a dict (created empty if it is currently None) and
        ``value`` is stored under ``k``. If the slot already holds a
        non-dict value, item assignment is attempted on it as-is.
        """
        if k is None:
            setattr(self, name, value)
            return
        slot = getattr(self, name)
        if slot is None:
            slot = {}
            setattr(self, name, slot)
        slot[k] = value

    def update_stat(self, status):
        """Set the status flag."""
        self.stat = status

    def update_all_group_ids(self, group_ids):
        """Replace ``all_group_ids``."""
        self.all_group_ids = group_ids

    def update_topk_group_ids(self, group_ids):
        """Replace ``topk_group_ids``."""
        self.topk_group_ids = group_ids

    def update_enc_hidden(self, enc_hidden, k=None):
        """Set ``enc_hidden``, or ``enc_hidden[k]`` when ``k`` is given."""
        self._update_slot("enc_hidden", enc_hidden, k)

    def update_synth_tokens(self, synth_tokens, k=None):
        """Set ``synth_tokens``, or ``synth_tokens[k]`` when ``k`` is given."""
        self._update_slot("synth_tokens", synth_tokens, k)

    def update_dec_hidden(self, dec_hidden, k=None):
        """Set ``dec_hidden``, or ``dec_hidden[k]`` when ``k`` is given."""
        self._update_slot("dec_hidden", dec_hidden, k)

    def update_dec_out(self, dec_out, k=None):
        """Set ``dec_out``, or ``dec_out[k]`` when ``k`` is given."""
        self._update_slot("dec_out", dec_out, k)

    def reset(self):
        """Restore the status flag, decoder slots and group-id slots.

        ``enc_hidden`` and ``synth_tokens`` are not cleared (unchanged from
        the previous behavior) -- NOTE(review): confirm that is intended.
        """
        self.stat = 1
        self.dec_hidden = None
        self.dec_out = None
        # Previously reset() only wrote a stray ``group_ids`` attribute, so the
        # real group-id slots kept stale values across resets.
        self.all_group_ids = None
        self.topk_group_ids = None
        # Legacy attribute kept so callers that read it after reset() still work.
        self.group_ids = None
|