[examples/seq2seq]: add --label_smoothing option (#5919)

This commit is contained in:
Sam Shleifer
2020-07-21 16:51:39 -04:00
committed by GitHub
parent 95d1962b9c
commit 5b193b39b0
7 changed files with 132 additions and 46 deletions

View File

@@ -19,6 +19,29 @@ from torch.utils.data import Dataset, Sampler
from transformers import BartTokenizer
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
"""From fairseq"""
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(-1)
nll_loss = -lprobs.gather(dim=-1, index=target)
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
nll_loss.masked_fill_(pad_mask, 0.0)
smooth_loss.masked_fill_(pad_mask, 0.0)
bs = pad_mask.long().sum()
else:
nll_loss = nll_loss.squeeze(-1)
smooth_loss = smooth_loss.squeeze(-1)
bs = lprobs.shape[0]
nll_loss = nll_loss.sum() # mean()? Scared to break other math.
smooth_loss = smooth_loss.sum()
eps_i = epsilon / lprobs.size(-1)
loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
return loss / bs, nll_loss / bs
def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
return tokenizer(
@@ -144,8 +167,8 @@ class MBartDataset(Seq2SeqDataset):
assert source_line, f"empty source line for index {index}"
assert tgt_line, f"empty tgt line for index {index}"
return {
"tgt_texts": source_line,
"src_texts": tgt_line,
"tgt_texts": tgt_line,
"src_texts": source_line,
}
def collate_fn(self, batch) -> Dict[str, torch.Tensor]: