import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingLoss(nn.Module):
    """
    With label smoothing, the KL-divergence between
    q_{smoothed ground truth prob.}(w) and
    p_{prob. computed by model}(w) is minimized.
    """

    def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
        assert 0.0 < label_smoothing <= 1.0
        super(LabelSmoothingLoss, self).__init__()
        self.padding_idx = ignore_index

        # Spread the smoothing mass over every class except the true label
        # and the padding index (hence `tgt_vocab_size - 2`).
        smoothing_value = label_smoothing / (tgt_vocab_size - 2)
        one_hot = torch.full((tgt_vocab_size,), smoothing_value)
        one_hot[self.padding_idx] = 0
        self.register_buffer('one_hot', one_hot.unsqueeze(0))

        self.confidence = 1.0 - label_smoothing

    def forward(self, output, target):
        """
        output (FloatTensor): batch_size x n_classes, log-probabilities
        target (LongTensor): batch_size
        """
        # Build the smoothed target distribution: the smoothing mass
        # everywhere, `confidence` at the gold label, and all zeros for
        # positions whose target is the padding index.
        model_prob = self.one_hot.repeat(target.size(0), 1)
        model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
        model_prob.masked_fill_((target == self.padding_idx).unsqueeze(1), 0)

        # F.kl_div expects log-probabilities as its first argument.
        return F.kl_div(output, model_prob, reduction='sum')
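
# Illustrative usage sketch, not part of the original file: the names
# `vocab_size`, `logits`, and `targets` below are assumptions for the demo.
# Since `forward` expects log-probabilities, raw logits are passed through
# F.log_softmax first; index 0 stands in for the padding token here.
if __name__ == '__main__':
    vocab_size = 10
    criterion = LabelSmoothingLoss(label_smoothing=0.1,
                                   tgt_vocab_size=vocab_size,
                                   ignore_index=0)
    logits = torch.randn(4, vocab_size)   # batch_size x n_classes
    targets = torch.tensor([2, 5, 0, 7])  # 0 marks a padded position
    loss = criterion(F.log_softmax(logits, dim=-1), targets)
    print(loss.item())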