import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingLoss(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    With label smoothing, the KL-divergence between
    q_{smoothed ground truth prob.}(w) and p_{prob. computed by model}(w)
    is minimized.

    The smoothed distribution puts ``1 - label_smoothing`` on the gold class
    and spreads ``label_smoothing`` uniformly over the remaining classes.

    Args:
        label_smoothing: Probability mass moved away from the gold class;
            must satisfy ``0 < label_smoothing <= 1``.
        tgt_vocab_size: Number of classes (size of the last dimension of the
            model output).
        ignore_index: Target value whose rows contribute nothing to the loss.
            If it lies in ``[0, tgt_vocab_size)`` it is treated as a real
            padding class that never receives probability mass. Any other
            value (such as the default ``-100``) is treated as a pure
            sentinel that does not correspond to a vocabulary entry.
    """

    def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
        assert 0.0 < label_smoothing <= 1.0
        # Initialise the Module machinery before assigning any attribute.
        super(LabelSmoothingLoss, self).__init__()
        self.padding_idx = ignore_index

        # Only an in-vocabulary padding index occupies a column that has to
        # be excluded from the smoothing mass. Indexing the buffer with an
        # out-of-range sentinel would raise, and a negative one would wrap
        # around and zero out an unrelated real class.
        padding_in_vocab = 0 <= ignore_index < tgt_vocab_size
        # Classes sharing the smoothing mass: everything except the gold
        # class and, when present in the vocabulary, the padding class.
        excluded = 2 if padding_in_vocab else 1
        smoothing_value = label_smoothing / (tgt_vocab_size - excluded)

        one_hot = torch.full((tgt_vocab_size,), smoothing_value)
        if padding_in_vocab:
            one_hot[ignore_index] = 0
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer('one_hot', one_hot.unsqueeze(0))

        self.confidence = 1.0 - label_smoothing

    def forward(self, output, target):
        """Return the summed KL-divergence for a batch.

        Args:
            output (FloatTensor): ``batch_size x n_classes``. Passed directly
                to ``F.kl_div`` as its input, which expects log-probabilities.
            target (LongTensor): ``batch_size`` gold class indices; entries
                equal to ``ignore_index`` are excluded from the loss.

        Returns:
            Scalar tensor: the KL-divergence summed over all non-ignored
            rows and all classes (``reduction='sum'``).
        """
        ignored = target == self.padding_idx
        # scatter_ needs valid column indices; point ignored rows at column 0.
        # Those rows are zeroed out right after, so the choice is irrelevant.
        safe_target = target.masked_fill(ignored, 0)

        model_prob = self.one_hot.repeat(target.size(0), 1)
        model_prob.scatter_(1, safe_target.unsqueeze(1), self.confidence)
        model_prob.masked_fill_(ignored.unsqueeze(1), 0)

        return F.kl_div(output, model_prob, reduction='sum')