From 9bed35544905d735566dccf88d2c65f1c211c675 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Sat, 8 Aug 2020 13:51:12 +0530 Subject: [PATCH] [s2s] fix label_smoothed_nll_loss (#6344) --- examples/seq2seq/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py index 1c13c0aa28..20440e3379 100644 --- a/examples/seq2seq/utils.py +++ b/examples/seq2seq/utils.py @@ -29,17 +29,15 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): pad_mask = target.eq(ignore_index) nll_loss.masked_fill_(pad_mask, 0.0) smooth_loss.masked_fill_(pad_mask, 0.0) - bs = pad_mask.long().sum() else: nll_loss = nll_loss.squeeze(-1) smooth_loss = smooth_loss.squeeze(-1) - bs = lprobs.shape[0] nll_loss = nll_loss.sum() # mean()? Scared to break other math. smooth_loss = smooth_loss.sum() eps_i = epsilon / lprobs.size(-1) loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss - return loss / bs, nll_loss / bs + return loss, nll_loss def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):