import os
import copy
import logging
import re
import torch
import numpy as np
import codecs
from tqdm import tqdm
from collections import OrderedDict
from torch.utils.data import DataLoader
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.auxiliaries.scheduler_builder import get_scheduler
from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.core.trainers.context import lifecycle, CtxVar
from federatedscope.core.trainers.enums import LIFECYCLE, MODE
from federatedscope.core.trainers.utils import filter_by_specified_keywords
from federatedscope.core.monitors import MetricCalculator
from federatedscope.core.monitors.metric_calculator import eval_acc
from federatedscope.nlp.hetero_tasks.trainer.utils import AverageMeter, \
    ContrastiveMonitor
from federatedscope.nlp.hetero_tasks.dataset.utils import setup_tokenizer
from federatedscope.nlp.hetero_tasks.dataset.squad import SquadResult
from federatedscope.nlp.hetero_tasks.dataset.newsqa import NewsQAResult

logger = logging.getLogger(__name__)


class ATCTrainer(GeneralTorchTrainer):
    """Trainer for heterogeneous NLP tasks (ATC) on top of
    ``GeneralTorchTrainer``.

    Supports several task families, dispatched on ``config.model.task``:
    'pretrain' (mlm/denoise), classification ('imdb', 'agnews'),
    QA ('squad', 'newsqa'), and generation ('cnndm', 'msqg').
    Optionally runs a contrastive-learning phase driven by a shared
    ``ContrastiveMonitor`` when ``config.model.use_contrastive_loss`` is set.
    """
    def __init__(self,
                 model,
                 data,
                 device,
                 config,
                 only_for_eval=False,
                 monitor=None):
        super().__init__(model, data, device, config, only_for_eval, monitor)
        self.metric_calculator = MetricCalculator(config.eval.metrics)
        self.task = config.model.task
        # Client ID; assigned later via update_stat().
        self.ID = None
        # Load checkpoint lazily, only once, on the first fit.
        self.load_ckpt = True
        # Output file handles for generation tasks (opened in update_stat).
        self.pred_file, self.src_file, self.tgt_file = None, None, None
        self.finish_eval = False
        self.ctx.eval_metrics = None
        self.ctx.tokenizer = setup_tokenizer(config.model.model_type)
        self.ctx.grad_accum_count = config.grad.grad_accum_count
        self.ctx.padding_idx = self.ctx.tokenizer.pad_token_id
        # Snapshot of initial weights, used by get_model_grads() to compute
        # weight deltas after local training.
        self.ctx.init_params = copy.deepcopy(model.state_dict())
        self.pretrain_task = None
        self.use_contrastive_loss = config.model.use_contrastive_loss
        self.ctx.contrast_monitor = ContrastiveMonitor() if \
            self.use_contrastive_loss else None

    def update(self, model_parameters, strict=False):
        """Load new parameters and re-snapshot them as the delta baseline."""
        super().update(model_parameters, strict=strict)
        self.ctx.init_params = copy.deepcopy(self.ctx.model.state_dict())

    def update_pretrain_task(self, task):
        """Set the current pretraining sub-task ('mlm' or 'denoise')."""
        self.pretrain_task = task

    def update_stat(self, ID):
        """Bind this trainer to client ``ID``; for generation tasks also
        open per-client pred/src/tgt output files under ``cfg.outdir``."""
        self.ID = ID
        if self.task in {'cnndm', 'msqg'}:
            pred_dir = os.path.join(self.cfg.outdir, 'pred')
            src_dir = os.path.join(self.cfg.outdir, 'src')
            tgt_dir = os.path.join(self.cfg.outdir, 'tgt')
            self.ctx.pred_path = os.path.join(pred_dir, '%d.txt' % ID)
            self.ctx.src_path = os.path.join(src_dir, '%d.txt' % ID)
            self.ctx.tgt_path = os.path.join(tgt_dir, '%d.txt' % ID)
            os.makedirs(pred_dir, exist_ok=True)
            os.makedirs(src_dir, exist_ok=True)
            os.makedirs(tgt_dir, exist_ok=True)
            self.pred_file = codecs.open(self.ctx.pred_path, 'w', 'utf-8')
            self.src_file = codecs.open(self.ctx.src_path, 'w', 'utf-8')
            self.tgt_file = codecs.open(self.ctx.tgt_path, 'w', 'utf-8')
        self.ctx.model.update_client_id(ID)

    def update_contrast_monitor(self, contrast_monitor):
        """Replace the shared contrastive monitor (e.g. from the server)."""
        self.ctx.contrast_monitor = contrast_monitor

    def get_model_grads(self, filter_keywords=None):
        """Return {name: current - init} weight deltas, excluding
        personalized (locally kept) parameters.

        NOTE(review): despite the name these are weight differences since the
        last update()/init, not autograd gradients.
        """
        if filter_keywords is None:
            filter_keywords = self.ctx.cfg.personalization.local_param
        grads = {}
        for n, p2 in self.ctx.model.state_dict().items():
            if filter_by_specified_keywords(n, filter_keywords):  # preserve
                grads[n] = p2 - self.ctx.init_params[n]
        return grads

    def parse_data(self, data):
        """Populate the ctx dict from ``data``.

        For each split, registers ``{split}_loader``, ``num_{split}_data``,
        ``{split}_encoded`` and ``{split}_examples``. With contrastive loss
        the raw splits are 'train_raw'/'train_contrast', and 'train_raw' is
        additionally aliased to the plain 'train' keys.

        Raises:
            TypeError: if ``data`` is not a dict or a split's 'dataloader'
                is not a ``DataLoader``.
        """
        init_dict = dict()
        if isinstance(data, dict):
            all_split = ['train', 'val', 'test'] if not \
                self.cfg.model.use_contrastive_loss else \
                ['train_raw', 'train_contrast', 'val', 'test']
            for split in all_split:
                init_dict['{}_data'.format(split)] = None
                init_dict['{}_loader'.format(split)] = None
                init_dict['num_{}_data'.format(split)] = 0
                init_dict['{}_encoded'.format(split)] = None
                init_dict['{}_examples'.format(split)] = None
                if data.get(split, None) is not None:
                    if isinstance(data.get(split)['dataloader'], DataLoader):
                        init_dict['{}_loader'.format(split)] = \
                            data.get(split)['dataloader']
                        init_dict['num_{}_data'.format(split)] = \
                            len(data.get(split)['dataloader'].dataset)
                        init_dict['{}_encoded'.format(split)] = \
                            data.get(split)['encoded']
                        init_dict['{}_examples'.format(split)] = \
                            data.get(split)['examples']
                        # Alias 'train_raw' to the default 'train' keys so
                        # downstream code can keep using *_train_* names.
                        if self.cfg.model.use_contrastive_loss and \
                                split == 'train_raw':
                            init_dict['train_data'] = None
                            init_dict['train_loader'] = \
                                data.get(split)['dataloader']
                            init_dict['num_train_data'] = \
                                len(data.get(split)['dataloader'].dataset)
                            init_dict['train_encoded'] = \
                                data.get(split)['encoded']
                            init_dict['train_examples'] = \
                                data.get(split)['examples']
                    else:
                        raise TypeError('Type {} is not supported.'.format(
                            type(data.get(split))))
        else:
            raise TypeError('Type of data should be dict.')
        return init_dict

    def setup_optimizer_and_scheduler(self, ctx):
        """Build the optimizer and warmup scheduler for the current mode.

        Total steps = (batches per round // grad_accum_count) * total rounds.
        NOTE(review): warmup_ratio is read from cfg[ctx.cur_mode].scheduler
        but the scheduler kwargs come from cfg.train.scheduler — confirm this
        asymmetry is intended.
        """
        total_steps = getattr(ctx, f'num_total_{ctx.cur_mode}_batch',
                              None) // ctx.cfg.grad.grad_accum_count * \
            ctx.cfg.federate.total_round_num
        warmup_steps = int(ctx.cfg[ctx.cur_mode].scheduler.warmup_ratio *
                           total_steps)
        optimizer = get_optimizer(ctx.model, **ctx.cfg[ctx.cur_mode].optimizer)
        scheduler = get_scheduler(optimizer,
                                  **ctx.cfg.train.scheduler,
                                  total_steps=total_steps,
                                  warmup_steps=warmup_steps)
        return optimizer, scheduler

    def _load_model(self, ctx):
        """Load a global checkpoint (then overlay a per-client one, if any)
        from ``cfg.federate.atc_load_from``.

        Only tensors whose name AND shape match the current model are copied,
        so partially compatible checkpoints load without error.

        Raises:
            RuntimeError: if no global checkpoint file can be found.
        """
        load_path = ctx.cfg.federate.atc_load_from
        global_ckpt_path = os.path.join(load_path, 'global_model.pt')
        client_ckpt_path = \
            os.path.join(load_path, 'client', 'client_model_{}.pt'.format(
                self.ID))
        if not os.path.exists(global_ckpt_path):
            # Fall back to the per-client global checkpoint layout.
            global_dir = os.path.join(load_path, 'global')
            global_ckpt_path = \
                os.path.join(global_dir, 'global_model_{}.pt'.format(self.ID))
            if not os.path.exists(global_ckpt_path):
                raise RuntimeError(
                    'Checkpoint NOT found in \'{}\''.format(global_ckpt_path))

        model_ckpt = ctx.model.state_dict()
        logger.info('Loading model from \'{}\''.format(global_ckpt_path))
        global_ckpt = torch.load(global_ckpt_path, map_location='cpu')['model']
        model_ckpt.update({
            k: v
            for k, v in global_ckpt.items()
            if k in model_ckpt and v.size() == model_ckpt[k].size()
        })
        if os.path.exists(client_ckpt_path):
            logger.info('Updating model from \'{}\''.format(client_ckpt_path))
            client_ckpt = torch.load(client_ckpt_path,
                                     map_location='cpu')['model']
            model_ckpt.update({
                k: v
                for k, v in client_ckpt.items()
                if k in model_ckpt and v.size() == model_ckpt[k].size()
            })
        ctx.model.load_state_dict(model_ckpt)

    def _save_model(self, ctx):
        """Persist this client's personalized parameters (those matching
        ``cfg.personalization.local_param``) plus the epoch/batch cursor."""
        if len(ctx.cfg.personalization.local_param) > 0:
            model_ckpt = OrderedDict({
                k: v
                for k, v in ctx.model.state_dict().items()
                if re.search('|'.join(ctx.cfg.personalization.local_param),
                             k) is not None
            })
            ckpt = {
                'model': model_ckpt,
                'epoch': ctx.cur_epoch_i + 1,
                'batch': ctx.cur_batch_i + 1,
            }
            save_dir = os.path.join(ctx.cfg.federate.save_to, 'client')
            os.makedirs(save_dir, exist_ok=True)
            ckpt_path = os.path.join(save_dir,
                                     'client_model_{}.pt'.format(self.ID))
            torch.save(ckpt, ckpt_path)

    def _remove_special_tokens(self, sent):
        """Strip BERT-style special tokens from a decoded sentence.

        NOTE(review): ``str.replace`` is literal, not regex — the
        ``r' +'`` pattern only removes the exact two-char substring
        "space plus", it does NOT collapse repeated spaces; confirm intent.
        """
        return sent.replace('[CLS]', '').replace('[SEP]', '').\
            replace('[PAD]', '').replace('[unused0]', '').\
            replace('[unused3]', '').replace('[unused1]', '').\
            replace(r' +', ' ').replace(' [unused2] ', '').\
            replace('[unused2]', '').strip()

    @property
    def _in_contrast_prepare(self):
        # True while the contrastive phase is collecting decoder outputs on
        # the synthetic data (monitor stat == 1) during local training.
        return self.use_contrastive_loss and \
            self.task != 'pretrain' and \
            self.ctx.cur_split == 'train' and \
            self.ctx.contrast_monitor.stat == 1

    def train(self, target_data_split_name='train', hooks_set=None):
        """Run one round of local training.

        Returns ``(num_samples, model_para, model_grads, eval_metrics)``,
        with the contrast monitor inserted before eval_metrics when
        contrastive loss is enabled.
        """
        hooks_set = hooks_set or self.hooks_in_train
        self.ctx.check_split(target_data_split_name)
        num_samples = self._run_routine(MODE.TRAIN, hooks_set,
                                        target_data_split_name)
        if not self.use_contrastive_loss:
            return \
                num_samples, self.get_model_para(), self.get_model_grads(), \
                self.ctx.eval_metrics
        return \
            num_samples, self.get_model_para(), self.get_model_grads(), \
            self.ctx.contrast_monitor, self.ctx.eval_metrics

    @lifecycle(LIFECYCLE.ROUTINE)
    def _run_routine(self, mode, hooks_set, dataset_name=None):
        """Run fit hooks + epochs; in the contrast-prepare phase, temporarily
        rewire the epoch/batch counters to iterate the synthetic contrastive
        data exactly once, restoring the originals afterwards."""
        if self.finish_eval:
            return self.ctx.num_samples

        raw_num_train_epoch, raw_num_train_batch = None, None
        if self._in_contrast_prepare:
            # Save the real counters, then size a single pass over the
            # monitor's synthetic tokens (ceil division into batches).
            raw_num_train_epoch, raw_num_train_batch = \
                self.ctx.num_train_epoch, self.ctx.num_train_batch
            batch_size = self.ctx.cfg.data.batch_size
            num_contrast_data = len(self.ctx.contrast_monitor.synth_tokens)
            self.ctx.num_train_epoch = 1
            self.ctx.num_train_batch = \
                num_contrast_data // batch_size + bool(num_contrast_data %
                                                       batch_size)
            self.ctx.num_train_batch_last_epoch = self.ctx.num_train_batch
            self.ctx.num_total_train_batch = \
                self.ctx.num_train_epoch * self.ctx.num_train_batch

        for hook in hooks_set['on_fit_start']:
            hook(self.ctx)
        self._run_epoch(hooks_set)
        for hook in hooks_set["on_fit_end"]:
            hook(self.ctx)

        # Restore the counters mutated for the contrast-prepare pass.
        if raw_num_train_epoch is not None and raw_num_train_batch is not None:
            self.ctx.num_train_epoch = raw_num_train_epoch
            self.ctx.num_train_batch = raw_num_train_batch
            self.ctx.num_train_batch_last_epoch = self.ctx.num_train_batch
            self.ctx.num_total_train_batch = \
                self.ctx.num_train_epoch * self.ctx.num_train_batch

        return self.ctx.num_samples

    @lifecycle(LIFECYCLE.BATCH)
    def _run_batch(self, hooks_set):
        """Iterate batches of the current split, firing the batch hooks; the
        progress bar is shown only for contrast-prepare and test runs."""
        for batch_i in tqdm(range(
                getattr(self.ctx, f"num_{self.ctx.cur_split}_batch", None)),
                            disable=not (self._in_contrast_prepare
                                         or self.ctx.cur_split == "test")):
            self.ctx.cur_batch_i = CtxVar(batch_i, LIFECYCLE.BATCH)

            for hook in hooks_set["on_batch_start"]:
                hook(self.ctx)
            for hook in hooks_set["on_batch_forward"]:
                hook(self.ctx)
            for hook in hooks_set["on_batch_backward"]:
                hook(self.ctx)
            for hook in hooks_set["on_batch_end"]:
                hook(self.ctx)

            # Break in the final epoch
            if self.ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE] and \
                    self.ctx.cur_epoch_i == getattr(
                        self.ctx, f'num_{self.ctx.cur_mode}_epoch', None) - 1:
                if batch_i >= \
                        getattr(self.ctx,
                                f'num_{self.ctx.cur_mode}_batch_last_epoch',
                                None) - 1:
                    break

    def _hook_on_fit_start_init(self, ctx):
        """Prepare model/optimizer/scheduler, one-shot checkpoint loading,
        per-routine statistics, and (with contrastive loss) pick the right
        train loader for the current phase."""
        ctx.model.to(ctx.device)
        if ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE]:
            # Reuse the cached optimizer/scheduler across rounds if present.
            ctx.optimizer = ctx.get(f'{ctx.cur_mode}_optimizer', None)
            ctx.scheduler = ctx.get(f'{ctx.cur_mode}_scheduler', None)
            if ctx.optimizer is None or ctx.scheduler is None:
                ctx.optimizer, ctx.scheduler = \
                    self.setup_optimizer_and_scheduler(ctx)
                setattr(ctx, f'{ctx.cur_mode}_optimizer', ctx.optimizer)
                setattr(ctx, f'{ctx.cur_mode}_scheduler', ctx.scheduler)
            if ctx.cfg.federate.atc_load_from and self.load_ckpt:
                self._load_model(ctx)
                self.load_ckpt = False
        # NOTE(review): this second load guard looks redundant with the one
        # above for cur_split == 'train' (load_ckpt is already False then);
        # presumably it covers loading when the first fit is not TRAIN mode —
        # confirm.
        if ctx.cur_split == 'train' and ctx.cfg.federate.atc_load_from \
                and self.load_ckpt:
            self._load_model(ctx)
            self.load_ckpt = False

        # prepare statistics
        ctx.loss_agg = CtxVar(AverageMeter(), LIFECYCLE.ROUTINE)
        ctx.loss_batch_total = CtxVar(0, LIFECYCLE.ROUTINE)
        ctx.loss_regular_total = CtxVar(0, LIFECYCLE.ROUTINE)
        ctx.num_samples = CtxVar(0, LIFECYCLE.ROUTINE)
        ctx.accum_steps = CtxVar(0, LIFECYCLE.ROUTINE)
        ctx.ys_true = CtxVar([], LIFECYCLE.ROUTINE)
        ctx.ys_pred = CtxVar([], LIFECYCLE.ROUTINE)
        ctx.squad_results = CtxVar([], LIFECYCLE.ROUTINE)
        ctx.newsqa_results = CtxVar([], LIFECYCLE.ROUTINE)

        if self.use_contrastive_loss:
            if self._in_contrast_prepare:
                # Phase 1: iterate the synthetic contrastive data.
                ctx.train_loader = ctx.train_contrast_loader
            else:
                ctx.regular_loss_agg = CtxVar(AverageMeter(),
                                              LIFECYCLE.ROUTINE)
                ctx.contrastive_loss_agg = CtxVar(AverageMeter(),
                                                  LIFECYCLE.ROUTINE)
                ctx.train_loader = ctx.train_raw_loader

    def _hook_on_batch_forward(self, ctx):
        """Task-dispatched forward pass.

        Sets ctx.batch_size / loss_batch / y_true / y_pred per task; for QA
        eval, accumulates Squad/NewsQA results; for generation at test time,
        decodes and writes pred/src/tgt files and returns early; in the
        contrast-prepare phase, only records decoder outputs/hiddens into the
        contrast monitor.
        """
        if self.use_contrastive_loss:
            ctx.contrastive_loss_batch = CtxVar(None, LIFECYCLE.BATCH)

        if self.task == 'pretrain':
            token_ids = ctx.data_batch[self.pretrain_task]['token_ids']
            attention_mask = \
                ctx.data_batch[self.pretrain_task]['attention_mask']
            labels = ctx.data_batch[self.pretrain_task]['labels']
            example_indices = \
                ctx.data_batch[self.pretrain_task]['example_indices']

            outputs = ctx.model(
                input_ids=token_ids.to(ctx.device),
                attention_mask=attention_mask.to(ctx.device),
                labels=labels.to(ctx.device),
                pretrain_task=self.pretrain_task,
                example_indices=example_indices,
            )
            ctx.batch_size = CtxVar(len(token_ids), LIFECYCLE.BATCH)
            ctx.loss_batch = CtxVar(outputs.loss, LIFECYCLE.BATCH)
            if self.pretrain_task == 'mlm':
                y_true = labels
            elif self.pretrain_task == 'denoise':
                # Shift off the BOS position for seq2seq denoising targets.
                y_true = labels[:, 1:]
            else:
                raise KeyError('Unsupported pretrain task: \'{}\''.format(
                    self.pretrain_task))
            # Score only positions that are neither ignored (-100) nor pad.
            count_idx = y_true.ne(-100) & y_true.ne(ctx.padding_idx)
            ctx.y_true = CtxVar(y_true[count_idx], LIFECYCLE.BATCH)
            ctx.y_pred = CtxVar(
                outputs.logits.argmax(dim=-1)[count_idx], LIFECYCLE.BATCH)
        else:
            token_ids = ctx.data_batch.get('token_ids', None)
            token_type_ids = ctx.data_batch.get('token_type_ids', None)
            attention_mask = ctx.data_batch.get('attention_mask', None)
            labels = ctx.data_batch.get('labels', None)
            start_positions = ctx.data_batch.get('start_positions', None)
            end_positions = ctx.data_batch.get('end_positions', None)
            example_indices = ctx.data_batch.get('example_indices', None)

            if self.task in {'imdb', 'agnews'}:  # sequence classification
                outputs = ctx.model(
                    input_ids=token_ids.to(ctx.device),
                    token_type_ids=token_type_ids.to(ctx.device),
                    attention_mask=attention_mask.to(ctx.device),
                    labels=labels.to(ctx.device),
                    contrast_monitor=ctx.contrast_monitor,
                    in_contrast_prepare=self._in_contrast_prepare,
                    example_indices=example_indices,
                )
                if not self._in_contrast_prepare:
                    ctx.batch_size = CtxVar(len(token_ids), LIFECYCLE.BATCH)
                    ctx.loss_batch = CtxVar(outputs.loss, LIFECYCLE.BATCH)
                    if self.use_contrastive_loss:
                        ctx.regular_loss_batch = CtxVar(
                            outputs.regular_loss, LIFECYCLE.BATCH)
                        ctx.contrastive_loss_batch = CtxVar(
                            outputs.contrastive_loss, LIFECYCLE.BATCH)
                    ctx.y_true = CtxVar(labels, LIFECYCLE.BATCH)
                    ctx.y_pred = CtxVar(outputs.logits.argmax(dim=-1),
                                        LIFECYCLE.BATCH)

            elif self.task in {'squad', 'newsqa'}:  # span-extraction QA
                outputs = ctx.model(
                    input_ids=token_ids.to(ctx.device),
                    token_type_ids=token_type_ids.to(ctx.device),
                    attention_mask=attention_mask.to(ctx.device),
                    start_positions=start_positions.to(ctx.device),
                    end_positions=end_positions.to(ctx.device),
                    contrast_monitor=ctx.contrast_monitor,
                    in_contrast_prepare=self._in_contrast_prepare,
                    example_indices=example_indices,
                )
                if not self._in_contrast_prepare:
                    # Collect per-example start/end logits for the official
                    # SQuAD/NewsQA evaluation during val/test.
                    for i, example_idx in enumerate(example_indices):
                        encoded_input = ctx.get('{}_encoded'.format(
                            ctx.cur_split))[example_idx.item()]
                        unique_id = int(encoded_input.unique_id)
                        start_logits = \
                            outputs.logits[0][i].detach().cpu().tolist()
                        end_logits = \
                            outputs.logits[1][i].detach().cpu().tolist()
                        if ctx.cur_split != 'train':
                            if self.task == 'squad':
                                ctx.squad_results.append(
                                    SquadResult(unique_id, start_logits,
                                                end_logits))
                            elif self.task == 'newsqa':
                                ctx.newsqa_results.append(
                                    NewsQAResult(unique_id, start_logits,
                                                 end_logits))
                    ctx.batch_size = CtxVar(len(token_ids), LIFECYCLE.BATCH)
                    ctx.loss_batch = CtxVar(outputs.loss, LIFECYCLE.BATCH)
                    if self.use_contrastive_loss:
                        ctx.regular_loss_batch = CtxVar(
                            outputs.regular_loss, LIFECYCLE.BATCH)
                        ctx.contrastive_loss_batch = CtxVar(
                            outputs.contrastive_loss, LIFECYCLE.BATCH)
                    ctx.y_true = CtxVar(
                        torch.cat([start_positions, end_positions]),
                        LIFECYCLE.BATCH)
                    ctx.y_pred = CtxVar(
                        torch.cat(
                            [out.argmax(dim=-1) for out in outputs.logits]),
                        LIFECYCLE.BATCH)

            elif self.task in {'cnndm', 'msqg'}:  # seq2seq generation
                if ctx.cur_split != 'test':
                    # Teacher-forced forward for train/val loss.
                    outputs = ctx.model(
                        input_ids=token_ids.to(ctx.device),
                        token_type_ids=token_type_ids.to(ctx.device),
                        attention_mask=attention_mask.to(ctx.device),
                        labels=labels.to(ctx.device),
                        contrast_monitor=ctx.contrast_monitor,
                        in_contrast_prepare=self._in_contrast_prepare,
                        example_indices=example_indices,
                    )
                    if not self._in_contrast_prepare:
                        ctx.batch_size = CtxVar(len(labels), LIFECYCLE.BATCH)
                        ctx.loss_batch = CtxVar(outputs.loss, LIFECYCLE.BATCH)
                        if self.use_contrastive_loss:
                            ctx.regular_loss_batch = CtxVar(
                                outputs.regular_loss, LIFECYCLE.BATCH)
                            ctx.contrastive_loss_batch = CtxVar(
                                outputs.contrastive_loss, LIFECYCLE.BATCH)
                        y_pred = outputs.logits.argmax(dim=-1)
                        y_true = labels[:, 1:]
                        non_padding_idx = y_true.ne(ctx.padding_idx)
                        ctx.y_true = CtxVar(y_true[non_padding_idx],
                                            LIFECYCLE.BATCH)
                        ctx.y_pred = CtxVar(y_pred[non_padding_idx],
                                            LIFECYCLE.BATCH)
                else:
                    # Test: free-running generation, decoded and written to
                    # the per-client pred/src/tgt files for ROUGE-style eval.
                    outputs = ctx.model.generate(
                        input_ids=token_ids.to(ctx.device),
                        token_type_ids=token_type_ids.to(ctx.device),
                        attention_mask=attention_mask.to(ctx.device),
                    )
                    # save to file
                    out_str = ctx.tokenizer.batch_decode(outputs)
                    src_str = ctx.tokenizer.batch_decode(token_ids)
                    ref_str = ctx.tokenizer.batch_decode(labels)
                    for out, src, ref in zip(out_str, src_str, ref_str):
                        out = self._remove_special_tokens(out)
                        src = self._remove_special_tokens(src)
                        ref = self._remove_special_tokens(ref)
                        self.pred_file.write(out + '\n')
                        self.src_file.write(src + '\n')
                        self.tgt_file.write(ref + '\n')
                    self.pred_file.flush()
                    self.src_file.flush()
                    self.tgt_file.flush()

                    ctx.batch_size = CtxVar(len(labels), LIFECYCLE.BATCH)
                    ctx.y_pred = CtxVar(outputs, LIFECYCLE.BATCH)
                    ctx.y_true = CtxVar(labels[:, 1:], LIFECYCLE.BATCH)
                    return

        if self._in_contrast_prepare:
            # Phase 1: no loss — stash decoder outputs/hiddens per example
            # into the shared monitor for the contrastive stage.
            ctx.batch_size = CtxVar(0, LIFECYCLE.BATCH)
            dec_out, dec_hidden, example_indices = \
                outputs.logits, outputs.hidden_states, outputs.example_indices
            if len(example_indices) > 0:
                for ex, out in zip(example_indices, dec_out.detach().cpu()):
                    ctx.contrast_monitor.update_dec_out(out, k=ex.item())
                for ex, hids in zip(example_indices,
                                    dec_hidden.detach().cpu()):
                    ctx.contrast_monitor.update_dec_hidden(hids, k=ex.item())
        else:
            ctx.loss_agg.update(ctx.loss_batch.detach().item(),
                                ctx.batch_size)
            if self.use_contrastive_loss:
                if ctx.regular_loss_batch is not None and \
                        ctx.contrastive_loss_batch is not None:
                    ctx.regular_loss_agg.update(
                        ctx.regular_loss_batch.detach().item(),
                        ctx.batch_size)
                    ctx.contrastive_loss_agg.update(
                        ctx.contrastive_loss_batch.detach().item(),
                        ctx.batch_size)

    def _hook_on_batch_forward_regularizer(self, ctx):
        # No regularizer during contrast-prepare (no loss is computed).
        if self._in_contrast_prepare:
            return
        super()._hook_on_batch_forward_regularizer(ctx)

    def _hook_on_batch_backward(self, ctx):
        """Backward with gradient accumulation; step/clip every
        ``grad_accum_count`` batches, log periodically, and save the
        personalized checkpoint at the end of the epoch."""
        if self._in_contrast_prepare:
            return

        cur_step = (ctx.cur_batch_i + 1) // ctx.grad_accum_count
        ctx.accum_steps += 1
        # Scale so the accumulated gradient matches a full-batch average.
        ctx.loss_task /= ctx.grad_accum_count
        ctx.loss_task.backward()
        if ctx.accum_steps == ctx.grad_accum_count:
            if ctx.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(ctx.model.parameters(),
                                               ctx.grad_clip)
            ctx.optimizer.step()
            ctx.scheduler.step()
            ctx.optimizer.zero_grad()
            ctx.accum_steps = CtxVar(0, LIFECYCLE.ROUTINE)

        total_epoch = getattr(ctx, f'num_{ctx.cur_mode}_epoch', None)
        total_batch = getattr(ctx, f'num_{ctx.cur_mode}_batch', None) if \
            ctx.cur_epoch_i + 1 < total_epoch else \
            getattr(ctx, f'num_{ctx.cur_mode}_batch_last_epoch', None)
        if ctx.accum_steps == 0:
            # Only log on optimizer-step boundaries.
            if cur_step > 1 and (cur_step % ctx.cfg.trainer.disp_freq == 0
                                 or ctx.cur_batch_i + 1 == total_batch):
                y_true = ctx.y_true.detach().cpu().numpy()
                y_pred = ctx.y_pred.detach().cpu().numpy()
                if y_true.ndim == 1:
                    y_true = np.expand_dims(y_true, axis=-1)
                if y_pred.ndim == 1:
                    y_pred = np.expand_dims(y_pred, axis=-1)
                cur_acc = eval_acc(y_true, y_pred)

                log_str = 'Epoch: [{}/{}][{}/{}]\t' \
                          'LR: {:.2e}\t' \
                          'Acc: {:.4f}\t' \
                          'Loss: {loss.val:.4f} ({loss.avg:.4f})'\
                    .format(ctx.cur_epoch_i + 1,
                            total_epoch,
                            cur_step,
                            total_batch // ctx.grad_accum_count,
                            ctx.scheduler.get_last_lr()[0],
                            cur_acc,
                            loss=ctx.loss_agg)
                if self.use_contrastive_loss:
                    log_str += \
                        '\tRegular loss: {loss.val:.4f} ' \
                        '({loss.avg:.4f})'.format(loss=ctx.regular_loss_agg)
                    log_str += \
                        '\tContrastive loss: {loss.val:.4f} ' \
                        '({loss.avg:.4f})'.format(
                            loss=ctx.contrastive_loss_agg)
                if self.task == 'pretrain':
                    log_str += '\t({})'.format(self.pretrain_task)
                logger.info(log_str)

        if ctx.cur_batch_i + 1 == total_batch and ctx.cfg.federate.save_to:
            self._save_model(ctx)

    def _hook_on_batch_end(self, ctx):
        """Accumulate routine statistics and cache labels/predictions."""
        if self._in_contrast_prepare:
            return

        # update statistics
        ctx.num_samples += ctx.batch_size
        ctx.loss_batch_total += ctx.get(
            "loss_batch", torch.tensor(0.)).item() * ctx.batch_size
        ctx.loss_regular_total += float(ctx.get("loss_regular", 0.))

        # cache label for evaluate
        # NOTE(review): for these tasks only the LAST batch is kept (the
        # list is replaced, not appended) — QA/generation metrics rely on
        # squad_results/newsqa_results or the written files instead; confirm.
        if self.task in {'pretrain', 'squad', 'newsqa', 'cnndm', 'msqg'}:
            ctx.ys_true = CtxVar([ctx.y_true.detach().cpu().numpy()],
                                 LIFECYCLE.ROUTINE)
            ctx.ys_pred = CtxVar([ctx.y_pred.detach().cpu().numpy()],
                                 LIFECYCLE.ROUTINE)
        else:
            ctx.ys_true.append(ctx.y_true.detach().cpu().numpy())
            ctx.ys_pred.append(ctx.y_pred.detach().cpu().numpy())

    def _hook_on_fit_end(self, ctx):
        """Advance the contrast monitor after a contrastive train pass;
        otherwise compute eval metrics for val/test and, once testing is
        done, close the generation output files."""
        if self.use_contrastive_loss and self.task != 'pretrain' and \
                ctx.cur_split == 'train':
            # Move the contrastive state machine to its next stage.
            ctx.contrast_monitor.update_stat(ctx.contrast_monitor.stat + 1)
            return

        if ctx.cur_split != 'train':
            ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true),
                                 LIFECYCLE.ROUTINE)
            ctx.ys_pred = CtxVar(np.concatenate(ctx.ys_pred),
                                 LIFECYCLE.ROUTINE)
            results = self.metric_calculator.eval(ctx)
            setattr(ctx, 'eval_metrics', results)

        if ctx.cur_split == 'test' and not self.finish_eval:
            if self.pred_file is not None:
                self.pred_file.close()
            if self.src_file is not None:
                self.src_file.close()
            if self.tgt_file is not None:
                self.tgt_file.close()
            self.finish_eval = True