Adding new argument max_new_tokens for generate. (#11476)
* Adding new argument `max_new_tokens` for generate. This is a proposal to add a new argument `max_new_tokens` to `generate`. This include a `MaxNewTokensCriteria` that enables callers that don't know about the token length ahead (like pipelines callers) to manage more easily the length of their generated output. * Adding a test for the user warning when both`max_length` and `max_new_tokens` are used together. * Removed redundant `no_grad`.
This commit is contained in:
@@ -12,6 +12,7 @@ if is_torch_available():
|
||||
|
||||
from transformers.generation_stopping_criteria import (
|
||||
MaxLengthCriteria,
|
||||
MaxNewTokensCriteria,
|
||||
MaxTimeCriteria,
|
||||
StoppingCriteriaList,
|
||||
validate_stopping_criteria,
|
||||
@@ -58,6 +59,21 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
||||
input_ids, scores = self._get_tensors(10)
|
||||
self.assertTrue(criteria(input_ids, scores))
|
||||
|
||||
def test_max_new_tokens_criteria(self):
|
||||
criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)
|
||||
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
input_ids, scores = self._get_tensors(9)
|
||||
self.assertFalse(criteria(input_ids, scores))
|
||||
|
||||
input_ids, scores = self._get_tensors(10)
|
||||
self.assertTrue(criteria(input_ids, scores))
|
||||
|
||||
criteria_list = StoppingCriteriaList([criteria])
|
||||
self.assertEqual(criteria_list.max_length, 10)
|
||||
|
||||
def test_max_time_criteria(self):
|
||||
input_ids, scores = self._get_tensors(5)
|
||||
|
||||
|
||||
@@ -1615,3 +1615,26 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
# BeamSearchScorer max_length should not influence "real" max_length
|
||||
self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
|
||||
|
||||
def test_max_new_tokens(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
self.assertEqual(list(input_ids.shape), [1, 15])
|
||||
|
||||
# Encoder decoder call
|
||||
max_new_tokens = 3
|
||||
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
|
||||
# 1 BOS + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 4])
|
||||
|
||||
# Decoder only call
|
||||
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
|
||||
# 15 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 18])
|
||||
|
||||
# max_new_tokens and max_length serve the same purpose and should not be used together.
|
||||
with self.assertWarns(UserWarning):
|
||||
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
|
||||
|
||||
Reference in New Issue
Block a user