Adding new argument max_new_tokens for generate. (#11476)

* Adding new argument `max_new_tokens` for generate.

This is a proposal to add a new argument `max_new_tokens` to `generate`.
This include a `MaxNewTokensCriteria` that enables callers that don't
know about the token length ahead (like pipelines callers) to manage
more easily the length of their generated output.

* Adding a test for the user warning when both`max_length` and
`max_new_tokens` are used together.

* Removed redundant `no_grad`.
This commit is contained in:
Nicolas Patry
2021-05-27 14:22:58 +02:00
committed by GitHub
parent 2dd6fb2585
commit 80d712fac6
4 changed files with 86 additions and 5 deletions

View File

@@ -1615,3 +1615,26 @@ class GenerationIntegrationTests(unittest.TestCase):
# BeamSearchScorer max_length should not influence "real" max_length
self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
def test_max_new_tokens(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
self.assertEqual(list(input_ids.shape), [1, 15])
# Encoder decoder call
max_new_tokens = 3
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
# 1 BOS + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 4])
# Decoder only call
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
# 15 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 18])
# max_new_tokens and max_length serve the same purpose and should not be used together.
with self.assertWarns(UserWarning):
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)