From 80d712fac6ccae308a2f408ebbc0c4d8c482d509 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 27 May 2021 14:22:58 +0200 Subject: [PATCH] Adding new argument `max_new_tokens` for generate. (#11476) * Adding new argument `max_new_tokens` for generate. This is a proposal to add a new argument `max_new_tokens` to `generate`. This includes a `MaxNewTokensCriteria` that enables callers that don't know the token length ahead of time (like pipeline callers) to more easily manage the length of their generated output. * Adding a test for the user warning when both `max_length` and `max_new_tokens` are used together. * Removed redundant `no_grad`. --- .../generation_stopping_criteria.py | 25 +++++++++++++++++ src/transformers/generation_utils.py | 27 +++++++++++++++---- tests/test_generation_stopping_criteria.py | 16 +++++++++++ tests/test_generation_utils.py | 23 ++++++++++++++++ 4 files changed, 86 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation_stopping_criteria.py b/src/transformers/generation_stopping_criteria.py index 65fef72464..112acdcb6d 100644 --- a/src/transformers/generation_stopping_criteria.py +++ b/src/transformers/generation_stopping_criteria.py @@ -57,6 +57,29 @@ class MaxLengthCriteria(StoppingCriteria): return input_ids.shape[-1] >= self.max_length +class MaxNewTokensCriteria(StoppingCriteria): + """ + This class can be used to stop generation whenever the generated number of tokens exceeds :obj:`max_new_tokens`. + Keep in mind for decoder-only type of transformers, this will **not** include the initial prompted tokens. This is + very close to :obj:`MaxLengthCriteria` but ignores the number of initial tokens. + + Args: + start_length (:obj:`int`): + The number of initial tokens. + max_new_tokens (:obj:`int`): + The maximum number of tokens to generate. 
+ """ + + def __init__(self, start_length: int, max_new_tokens: int): + self.start_length = start_length + self.max_new_tokens = max_new_tokens + self.max_length = start_length + max_new_tokens + + @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + return input_ids.shape[-1] >= self.max_length + + class MaxTimeCriteria(StoppingCriteria): """ This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the @@ -89,6 +112,8 @@ class StoppingCriteriaList(list): for stopping_criterium in self: if isinstance(stopping_criterium, MaxLengthCriteria): return stopping_criterium.max_length + elif isinstance(stopping_criterium, MaxNewTokensCriteria): + return stopping_criterium.max_length return None diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index cb04ff3377..bd3750ec43 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -42,6 +42,7 @@ from .generation_logits_process import ( ) from .generation_stopping_criteria import ( MaxLengthCriteria, + MaxNewTokensCriteria, MaxTimeCriteria, StoppingCriteriaList, validate_stopping_criteria, @@ -628,15 +629,15 @@ class GenerationMixin: return processors def _get_stopping_criteria( - self, - max_length: Optional[int], - max_time: Optional[float], + self, max_length: Optional[int], max_time: Optional[float], max_new_tokens: Optional[int], start_length: int ) -> StoppingCriteriaList: stopping_criteria = StoppingCriteriaList() if max_length is not None: stopping_criteria.append(MaxLengthCriteria(max_length=max_length)) if max_time is not None: stopping_criteria.append(MaxTimeCriteria(max_time=max_time)) + if max_new_tokens is not None: + stopping_criteria.append(MaxNewTokensCriteria(start_length=start_length, max_new_tokens=max_new_tokens)) return stopping_criteria @torch.no_grad() @@ -661,6 +662,7 @@ 
class GenerationMixin: encoder_no_repeat_ngram_size: Optional[int] = None, num_return_sequences: Optional[int] = None, max_time: Optional[float] = None, + max_new_tokens: Optional[int] = None, decoder_start_token_id: Optional[int] = None, use_cache: Optional[bool] = None, num_beam_groups: Optional[int] = None, @@ -692,8 +694,11 @@ class GenerationMixin: input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty :obj:`torch.LongTensor` of shape :obj:`(1,)`. - max_length (:obj:`int`, `optional`, defaults to 20): + max_length (:obj:`int`, `optional`, defaults to :obj:`model.config.max_length`): The maximum length of the sequence to be generated. + max_new_tokens (:obj:`int`, `optional`, defaults to None): + The maximum number of tokens to generate, ignoring the current number of tokens. Use either + :obj:`max_new_tokens` or :obj:`max_length` but not both; they serve the same purpose. min_length (:obj:`int`, `optional`, defaults to 10): The minimum length of the sequence to be generated. 
do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`): @@ -861,6 +866,15 @@ class GenerationMixin: """ # set init values + if max_length is None and max_new_tokens is None: + # Both are None, default + max_length = self.config.max_length + elif max_length is not None and max_new_tokens is not None: + # Both are set, this is odd, raise a warning + warnings.warn( + "Both `max_length` and `max_new_tokens` have been set but they serve the same purpose.", UserWarning + ) + max_length = max_length if max_length is not None else self.config.max_length num_beams = num_beams if num_beams is not None else self.config.num_beams num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups @@ -960,7 +974,10 @@ class GenerationMixin: remove_invalid_values=remove_invalid_values, ) - stopping_criteria = self._get_stopping_criteria(max_length=max_length, max_time=max_time) + cur_len = input_ids.shape[-1] + stopping_criteria = self._get_stopping_criteria( + max_length=max_length, max_time=max_time, max_new_tokens=max_new_tokens, start_length=cur_len + ) if is_greedy_gen_mode: if num_return_sequences > 1: diff --git a/tests/test_generation_stopping_criteria.py b/tests/test_generation_stopping_criteria.py index 995ea97736..d3de2c56da 100644 --- a/tests/test_generation_stopping_criteria.py +++ b/tests/test_generation_stopping_criteria.py @@ -12,6 +12,7 @@ if is_torch_available(): from transformers.generation_stopping_criteria import ( MaxLengthCriteria, + MaxNewTokensCriteria, MaxTimeCriteria, StoppingCriteriaList, validate_stopping_criteria, @@ -58,6 +59,21 @@ class StoppingCriteriaTestCase(unittest.TestCase): input_ids, scores = self._get_tensors(10) self.assertTrue(criteria(input_ids, scores)) + def test_max_new_tokens_criteria(self): + criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5) + + input_ids, scores = self._get_tensors(5) + self.assertFalse(criteria(input_ids, scores)) + + input_ids, scores = self._get_tensors(9) + 
self.assertFalse(criteria(input_ids, scores)) + + input_ids, scores = self._get_tensors(10) + self.assertTrue(criteria(input_ids, scores)) + + criteria_list = StoppingCriteriaList([criteria]) + self.assertEqual(criteria_list.max_length, 10) + def test_max_time_criteria(self): input_ids, scores = self._get_tensors(5) diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index 1134674a80..289fa4882c 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -1615,3 +1615,26 @@ class GenerationIntegrationTests(unittest.TestCase): # BeamSearchScorer max_length should not influence "real" max_length self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist()) + + def test_max_new_tokens(self): + article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" + bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random") + bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device) + input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device) + + self.assertEqual(list(input_ids.shape), [1, 15]) + + # Encoder decoder call + max_new_tokens = 3 + outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens) + # 1 BOS + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 4]) + + # Decoder only call + outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens) + # 15 + 3 new tokens + self.assertEqual(list(outputs.shape), [1, 18]) + + # max_new_tokens and max_length serve the same purpose and should not be used together. + with self.assertWarns(UserWarning): + outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)