Fix generation min length (#16206)

* up

* fix min lengths
This commit is contained in:
Patrick von Platen
2022-03-16 18:49:23 +01:00
committed by GitHub
parent 667b823b89
commit 2410d0f8ed
2 changed files with 5 additions and 2 deletions

View File

@@ -741,7 +741,7 @@ class GenerationMixin:
) )
if bad_words_ids is not None: if bad_words_ids is not None:
processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
if min_length is not None and eos_token_id is not None and min_length > -1: if min_length is not None and eos_token_id is not None and min_length > 0:
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
if prefix_allowed_tokens_fn is not None: if prefix_allowed_tokens_fn is not None:
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups)) processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups))

View File

@@ -1949,11 +1949,14 @@ class GenerationIntegrationTests(unittest.TestCase):
def test_custom_logits_processor(self): def test_custom_logits_processor(self):
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random", min_length=1).to(
torch_device
)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
logits_processor = LogitsProcessorList() logits_processor = LogitsProcessorList()
logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0)) logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0))
# it should not be allowed to both define `min_length` via config and `logits_processor` list
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
bart_model.generate(input_ids, logits_processor=logits_processor) bart_model.generate(input_ids, logits_processor=logits_processor)