class AverageMeter(object):
    """Track the latest value and the running (weighted) average of a quantity.

    Attributes:
        val: the most recently passed value.
        sum: running sum of ``val * n`` over all updates.
        cnt: running total of the weights ``n``.
        avg: ``sum / cnt``; stays at its previous value while ``cnt`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all accumulated statistics."""
        self.avg = 0
        self.sum = 0
        self.cnt = 0
        self.val = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the new observation (e.g. a per-batch mean).
            n: weight of the observation (e.g. batch size). ``n=0`` is
                allowed and does not change ``sum``/``cnt``.
        """
        self.val = val
        self.sum += val * n
        self.cnt += n
        # Guard against ZeroDivisionError when every update so far had n=0.
        if self.cnt != 0:
            self.avg = self.sum / self.cnt


class ContrastiveMonitor(object):
    """Mutable holder for intermediate state of a contrastive model pass.

    Each of ``enc_hidden``, ``synth_tokens``, ``dec_hidden`` and ``dec_out``
    is either a single object (set via ``update_*(value)``) or a dict keyed
    by ``k`` (built up via ``update_*(value, k=...)``). The concrete value
    types are not visible here — presumably tensors; verify against callers.
    """

    def __init__(self, stat=1, enc_hidden=None, synth_tokens=None,
                 dec_hidden=None, dec_out=None, all_group_ids=None,
                 topk_group_ids=None):
        # NOTE(review): meaning of the integer status flag is not visible
        # here; reset() restores it to 1.
        self.stat = stat
        self.enc_hidden = enc_hidden
        self.synth_tokens = synth_tokens
        self.dec_hidden = dec_hidden
        self.dec_out = dec_out
        self.all_group_ids = all_group_ids
        self.topk_group_ids = topk_group_ids

    def _set_slot(self, name, value, k=None):
        """Set attribute ``name`` to ``value``, or store ``value`` under key ``k``.

        When ``k`` is None the attribute is replaced outright. Otherwise the
        attribute is treated as a dict (created on first use) and
        ``value`` is stored at ``attr[k]``.
        """
        if k is None:
            setattr(self, name, value)
            return
        slot = getattr(self, name)
        if slot is None:
            slot = {}
            setattr(self, name, slot)
        slot[k] = value

    def update_stat(self, status):
        self.stat = status

    def update_all_group_ids(self, group_ids):
        self.all_group_ids = group_ids

    def update_topk_group_ids(self, group_ids):
        self.topk_group_ids = group_ids

    def update_enc_hidden(self, enc_hidden, k=None):
        self._set_slot('enc_hidden', enc_hidden, k)

    def update_synth_tokens(self, synth_tokens, k=None):
        self._set_slot('synth_tokens', synth_tokens, k)

    def update_dec_hidden(self, dec_hidden, k=None):
        self._set_slot('dec_hidden', dec_hidden, k)

    def update_dec_out(self, dec_out, k=None):
        self._set_slot('dec_out', dec_out, k)

    def reset(self):
        """Restore the status flag and clear decoder outputs and group ids.

        The original code set a non-existent ``group_ids`` attribute here,
        leaving ``all_group_ids``/``topk_group_ids`` stale; both are now
        cleared.
        """
        # NOTE(review): enc_hidden and synth_tokens are intentionally(?) left
        # untouched, matching the original behavior — confirm with callers.
        self.stat = 1
        self.dec_hidden = None
        self.dec_out = None
        self.all_group_ids = None
        self.topk_group_ids = None