Adding new argument max_new_tokens for generate. (#11476)
* Adding new argument `max_new_tokens` for generate. This is a proposal to add a new argument `max_new_tokens` to `generate`. This include a `MaxNewTokensCriteria` that enables callers that don't know about the token length ahead (like pipelines callers) to manage more easily the length of their generated output. * Adding a test for the user warning when both`max_length` and `max_new_tokens` are used together. * Removed redundant `no_grad`.
This commit is contained in:
@@ -57,6 +57,29 @@ class MaxLengthCriteria(StoppingCriteria):
|
|||||||
return input_ids.shape[-1] >= self.max_length
|
return input_ids.shape[-1] >= self.max_length
|
||||||
|
|
||||||
|
|
||||||
|
class MaxNewTokensCriteria(StoppingCriteria):
|
||||||
|
"""
|
||||||
|
This class can be used to stop generation whenever the generated number of tokens exceeds :obj:`max_new_tokens`.
|
||||||
|
Keep in mind for decoder-only type of transformers, this will **not** include the initial prompted tokens. This is
|
||||||
|
very close to :obj:`MaxLengthCriteria` but ignores the number of initial tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_length (:obj:`int`):
|
||||||
|
The number of initial tokens.
|
||||||
|
max_new_tokens (:obj:`int`):
|
||||||
|
The maximum number of tokens to generate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, start_length: int, max_new_tokens: int):
|
||||||
|
self.start_length = start_length
|
||||||
|
self.max_new_tokens = max_new_tokens
|
||||||
|
self.max_length = start_length + max_new_tokens
|
||||||
|
|
||||||
|
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
||||||
|
return input_ids.shape[-1] >= self.max_length
|
||||||
|
|
||||||
|
|
||||||
class MaxTimeCriteria(StoppingCriteria):
|
class MaxTimeCriteria(StoppingCriteria):
|
||||||
"""
|
"""
|
||||||
This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
|
This class can be used to stop generation whenever the full generation exceeds some amount of time. By default, the
|
||||||
@@ -89,6 +112,8 @@ class StoppingCriteriaList(list):
|
|||||||
for stopping_criterium in self:
|
for stopping_criterium in self:
|
||||||
if isinstance(stopping_criterium, MaxLengthCriteria):
|
if isinstance(stopping_criterium, MaxLengthCriteria):
|
||||||
return stopping_criterium.max_length
|
return stopping_criterium.max_length
|
||||||
|
elif isinstance(stopping_criterium, MaxNewTokensCriteria):
|
||||||
|
return stopping_criterium.max_length
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ from .generation_logits_process import (
|
|||||||
)
|
)
|
||||||
from .generation_stopping_criteria import (
|
from .generation_stopping_criteria import (
|
||||||
MaxLengthCriteria,
|
MaxLengthCriteria,
|
||||||
|
MaxNewTokensCriteria,
|
||||||
MaxTimeCriteria,
|
MaxTimeCriteria,
|
||||||
StoppingCriteriaList,
|
StoppingCriteriaList,
|
||||||
validate_stopping_criteria,
|
validate_stopping_criteria,
|
||||||
@@ -628,15 +629,15 @@ class GenerationMixin:
|
|||||||
return processors
|
return processors
|
||||||
|
|
||||||
def _get_stopping_criteria(
|
def _get_stopping_criteria(
|
||||||
self,
|
self, max_length: Optional[int], max_time: Optional[float], max_new_tokens: Optional[int], start_length: int
|
||||||
max_length: Optional[int],
|
|
||||||
max_time: Optional[float],
|
|
||||||
) -> StoppingCriteriaList:
|
) -> StoppingCriteriaList:
|
||||||
stopping_criteria = StoppingCriteriaList()
|
stopping_criteria = StoppingCriteriaList()
|
||||||
if max_length is not None:
|
if max_length is not None:
|
||||||
stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
|
stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
|
||||||
if max_time is not None:
|
if max_time is not None:
|
||||||
stopping_criteria.append(MaxTimeCriteria(max_time=max_time))
|
stopping_criteria.append(MaxTimeCriteria(max_time=max_time))
|
||||||
|
if max_new_tokens is not None:
|
||||||
|
stopping_criteria.append(MaxNewTokensCriteria(start_length=start_length, max_new_tokens=max_new_tokens))
|
||||||
return stopping_criteria
|
return stopping_criteria
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -661,6 +662,7 @@ class GenerationMixin:
|
|||||||
encoder_no_repeat_ngram_size: Optional[int] = None,
|
encoder_no_repeat_ngram_size: Optional[int] = None,
|
||||||
num_return_sequences: Optional[int] = None,
|
num_return_sequences: Optional[int] = None,
|
||||||
max_time: Optional[float] = None,
|
max_time: Optional[float] = None,
|
||||||
|
max_new_tokens: Optional[int] = None,
|
||||||
decoder_start_token_id: Optional[int] = None,
|
decoder_start_token_id: Optional[int] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
num_beam_groups: Optional[int] = None,
|
num_beam_groups: Optional[int] = None,
|
||||||
@@ -692,8 +694,11 @@ class GenerationMixin:
|
|||||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
|
The sequence used as a prompt for the generation. If :obj:`None` the method initializes it as an empty
|
||||||
:obj:`torch.LongTensor` of shape :obj:`(1,)`.
|
:obj:`torch.LongTensor` of shape :obj:`(1,)`.
|
||||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
max_length (:obj:`int`, `optional`, defaults to :obj:`model.config.max_length`):
|
||||||
The maximum length of the sequence to be generated.
|
The maximum length of the sequence to be generated.
|
||||||
|
max_new_tokens (:obj:`int`, `optional`, defaults to None):
|
||||||
|
The maximum numbers of tokens to generate, ignore the current number of tokens. Use either
|
||||||
|
:obj:`max_new_tokens` or :obj:`max_length` but not both, they serve the same purpose.
|
||||||
min_length (:obj:`int`, `optional`, defaults to 10):
|
min_length (:obj:`int`, `optional`, defaults to 10):
|
||||||
The minimum length of the sequence to be generated.
|
The minimum length of the sequence to be generated.
|
||||||
do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
@@ -861,6 +866,15 @@ class GenerationMixin:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# set init values
|
# set init values
|
||||||
|
if max_length is None and max_new_tokens is None:
|
||||||
|
# Both are None, default
|
||||||
|
max_length = self.config.max_length
|
||||||
|
elif max_length is not None and max_new_tokens is not None:
|
||||||
|
# Both are set, this is odd, raise a warning
|
||||||
|
warnings.warn(
|
||||||
|
"Both `max_length` and `max_new_tokens` have been set but they serve the same purpose.", UserWarning
|
||||||
|
)
|
||||||
|
|
||||||
max_length = max_length if max_length is not None else self.config.max_length
|
max_length = max_length if max_length is not None else self.config.max_length
|
||||||
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||||
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
|
num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
|
||||||
@@ -960,7 +974,10 @@ class GenerationMixin:
|
|||||||
remove_invalid_values=remove_invalid_values,
|
remove_invalid_values=remove_invalid_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
stopping_criteria = self._get_stopping_criteria(max_length=max_length, max_time=max_time)
|
cur_len = input_ids.shape[-1]
|
||||||
|
stopping_criteria = self._get_stopping_criteria(
|
||||||
|
max_length=max_length, max_time=max_time, max_new_tokens=max_new_tokens, start_length=cur_len
|
||||||
|
)
|
||||||
|
|
||||||
if is_greedy_gen_mode:
|
if is_greedy_gen_mode:
|
||||||
if num_return_sequences > 1:
|
if num_return_sequences > 1:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers.generation_stopping_criteria import (
|
from transformers.generation_stopping_criteria import (
|
||||||
MaxLengthCriteria,
|
MaxLengthCriteria,
|
||||||
|
MaxNewTokensCriteria,
|
||||||
MaxTimeCriteria,
|
MaxTimeCriteria,
|
||||||
StoppingCriteriaList,
|
StoppingCriteriaList,
|
||||||
validate_stopping_criteria,
|
validate_stopping_criteria,
|
||||||
@@ -58,6 +59,21 @@ class StoppingCriteriaTestCase(unittest.TestCase):
|
|||||||
input_ids, scores = self._get_tensors(10)
|
input_ids, scores = self._get_tensors(10)
|
||||||
self.assertTrue(criteria(input_ids, scores))
|
self.assertTrue(criteria(input_ids, scores))
|
||||||
|
|
||||||
|
def test_max_new_tokens_criteria(self):
|
||||||
|
criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)
|
||||||
|
|
||||||
|
input_ids, scores = self._get_tensors(5)
|
||||||
|
self.assertFalse(criteria(input_ids, scores))
|
||||||
|
|
||||||
|
input_ids, scores = self._get_tensors(9)
|
||||||
|
self.assertFalse(criteria(input_ids, scores))
|
||||||
|
|
||||||
|
input_ids, scores = self._get_tensors(10)
|
||||||
|
self.assertTrue(criteria(input_ids, scores))
|
||||||
|
|
||||||
|
criteria_list = StoppingCriteriaList([criteria])
|
||||||
|
self.assertEqual(criteria_list.max_length, 10)
|
||||||
|
|
||||||
def test_max_time_criteria(self):
|
def test_max_time_criteria(self):
|
||||||
input_ids, scores = self._get_tensors(5)
|
input_ids, scores = self._get_tensors(5)
|
||||||
|
|
||||||
|
|||||||
@@ -1615,3 +1615,26 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
# BeamSearchScorer max_length should not influence "real" max_length
|
# BeamSearchScorer max_length should not influence "real" max_length
|
||||||
self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
|
self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
|
||||||
|
|
||||||
|
def test_max_new_tokens(self):
|
||||||
|
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||||
|
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||||
|
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||||
|
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||||
|
|
||||||
|
self.assertEqual(list(input_ids.shape), [1, 15])
|
||||||
|
|
||||||
|
# Encoder decoder call
|
||||||
|
max_new_tokens = 3
|
||||||
|
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
|
||||||
|
# 1 BOS + 3 new tokens
|
||||||
|
self.assertEqual(list(outputs.shape), [1, 4])
|
||||||
|
|
||||||
|
# Decoder only call
|
||||||
|
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
|
||||||
|
# 15 + 3 new tokens
|
||||||
|
self.assertEqual(list(outputs.shape), [1, 18])
|
||||||
|
|
||||||
|
# max_new_tokens and max_length serve the same purpose and should not be used together.
|
||||||
|
with self.assertWarns(UserWarning):
|
||||||
|
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=10, max_length=20)
|
||||||
|
|||||||
Reference in New Issue
Block a user