From 2410d0f8edd954a6bfefcb426fbb43c63deca2c6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 16 Mar 2022 18:49:23 +0100 Subject: [PATCH] Fix generation min length (#16206) * up * fix min lengths --- src/transformers/generation_utils.py | 2 +- tests/generation/test_generation_utils.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 62f37ad624..de99ac3ed1 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -741,7 +741,7 @@ class GenerationMixin: ) if bad_words_ids is not None: processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) - if min_length is not None and eos_token_id is not None and min_length > -1: + if min_length is not None and eos_token_id is not None and min_length > 0: processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) if prefix_allowed_tokens_fn is not None: processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups)) diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index 818cbfe17e..031061041e 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -1949,11 +1949,14 @@ class GenerationIntegrationTests(unittest.TestCase): def test_custom_logits_processor(self): bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" - bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) + bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random", min_length=1).to( + torch_device + ) input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) logits_processor = LogitsProcessorList() logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0)) + # it should not be allowed to both define `min_length` via config and `logits_processor` list with self.assertRaises(ValueError): bart_model.generate(input_ids, logits_processor=logits_processor)