Adding new argument max_new_tokens for generate. (#11476)

* Adding new argument `max_new_tokens` for generate.

This is a proposal to add a new argument `max_new_tokens` to `generate`.
This include a `MaxNewTokensCriteria` that enables callers that don't
know about the token length ahead (like pipelines callers) to manage
more easily the length of their generated output.

* Adding a test for the user warning when both`max_length` and
`max_new_tokens` are used together.

* Removed redundant `no_grad`.
This commit is contained in:
Nicolas Patry
2021-05-27 14:22:58 +02:00
committed by GitHub
parent 2dd6fb2585
commit 80d712fac6
4 changed files with 86 additions and 5 deletions

View File

@@ -12,6 +12,7 @@ if is_torch_available():
from transformers.generation_stopping_criteria import (
MaxLengthCriteria,
MaxNewTokensCriteria,
MaxTimeCriteria,
StoppingCriteriaList,
validate_stopping_criteria,
@@ -58,6 +59,21 @@ class StoppingCriteriaTestCase(unittest.TestCase):
input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores))
def test_max_new_tokens_criteria(self):
criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)
input_ids, scores = self._get_tensors(5)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(9)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores))
criteria_list = StoppingCriteriaList([criteria])
self.assertEqual(criteria_list.max_length, 10)
def test_max_time_criteria(self):
input_ids, scores = self._get_tensors(5)