From 03e363f9ae0a2a75bcae7fd1247a1445342cbd7b Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 9 Sep 2020 01:08:36 -0700 Subject: [PATCH] [generation] consistently add eos tokens (#6982) Currently beam search returns inconsistent outputs - if hypos have different lengths we get eos, if they are the same - we don't. This PR makes the output consistent. Also why not also replace: ``` if sent_lengths[i] < max_length: decoded[i, sent_lengths[i]] = eos_token_id ``` with: ``` decoded[i, sent_lengths[i]] = eos_token_id ``` Shouldn't eos always be there? If the data gets truncated, the caller needs to user a larger `max_length`. Please correct me if my logic is flawed. --- src/transformers/generation_utils.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 053b13a343..a961785188 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -841,21 +841,19 @@ class GenerationMixin: sent_lengths[effective_batch_idx] = len(best_hyp) best.append(best_hyp) - # shorter batches are padded + # prepare for adding eos + sent_max_len = min(sent_lengths.max().item() + 1, max_length) + decoded = input_ids.new(output_batch_size, sent_max_len) + # shorter batches are padded if needed if sent_lengths.min().item() != sent_lengths.max().item(): - assert pad_token_id is not None, "`Pad_token_id` has to be defined" - sent_max_len = min(sent_lengths.max().item() + 1, max_length) - decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id) + assert pad_token_id is not None, "`pad_token_id` has to be defined" + decoded.fill_(pad_token_id) - # fill with hypothesis and eos_token_id if necessary - for i, hypo in enumerate(best): - decoded[i, : sent_lengths[i]] = hypo - if sent_lengths[i] < max_length: - decoded[i, sent_lengths[i]] = eos_token_id - else: - # none of the hypotheses have an eos_token - assert (len(hypo) == max_length for hypo in best) - decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device) + # fill with hypotheses and eos_token_id if the latter fits in + for i, hypo in enumerate(best): + decoded[i, : sent_lengths[i]] = hypo + if sent_lengths[i] < max_length: + decoded[i, sent_lengths[i]] = eos_token_id return decoded