Adding new parameter to generate: max_time. (#9846)
* [WIP] Adding new parameter to `generate`: `max_time`. Generation by tokens number is sometimes a bit clunky because we don't know how many tokens are good enough or even how many tokens are in the payload (for pipelines users for instance). This leads to hard to understand behavior. This PR proposes a new argument `max_time` which is a float of seconds for the allowed time for `generate` to run on. Ideally combinations of `max_tokens=None`, `max_time=2` could be used to generate as many tokens as possible within time budget. NB: Another possible approach consists of passing a callback to `generate` putting the caller in charge of the actual decision of when to stop generating tokens. It opens the door to 'which args should we pass' to this callback. It's hard to imagine other use-cases for this early stopping behavior than time (that are not already covered by parameters of generate) * Revamp with StoppingCriteria * Removing deprecated mentions. * Forgot arguments to stopping criteria. * Readding max_length it's not just used as a stopping criteria. * Default value for `stopping_criteria`. * Address @patrickvonplaten comments. - More docstrings - Actual doc - Include in global namespace - Remove TF work. * Put back `max_length` (deprecation different PR). * Doc quality. * Fixing old behavior without `stopping_criteria` but with `max_length`. Making sure we don't break that in the future. * Adding more tests for possible inconsistencies between `max_length` and `stopping_criteria`. * Fixing the torch imports.
This commit is contained in:
@@ -38,6 +38,7 @@ if is_torch_available():
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
)
|
||||
from transformers.generation_stopping_criteria import MaxLengthCriteria, StoppingCriteriaList
|
||||
from transformers.generation_utils import (
|
||||
BeamSampleDecoderOnlyOutput,
|
||||
BeamSampleEncoderDecoderOutput,
|
||||
@@ -1320,3 +1321,189 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
"Justin Timberlake and Jessica Biel have a son. The baby is named Silas Randall Timberlake. It is the first child for both. The couple announced the pregnancy in January. The name Silas is the middle name of Timberlake's maternal grandfather. It's also his own middle name.",
|
||||
],
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_greedy(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
max_length = 20
|
||||
input_ids = input_ids.expand(2, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids,
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
|
||||
bart_model.greedy_search(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_sample(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
max_length = 20
|
||||
input_ids = input_ids.expand(2, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids,
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
bart_model.sample(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_beam_search(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
batch_size = 1
|
||||
max_length = 20
|
||||
num_beams = 2
|
||||
|
||||
input_ids = input_ids.expand(2, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids,
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
)
|
||||
_ = bart_model.beam_search(
|
||||
input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
|
||||
)
|
||||
|
||||
def test_max_length_backward_compat_group_beam_search(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
batch_size = 1
|
||||
max_length = 20
|
||||
num_beams = 6
|
||||
num_beam_groups = 3
|
||||
num_return_sequences = num_beams * batch_size
|
||||
|
||||
input_ids = input_ids.expand(6, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids,
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
|
||||
diverse_beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
num_beam_groups=num_beam_groups,
|
||||
)
|
||||
bart_model.group_beam_search(
|
||||
input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
|
||||
)
|
||||
|
||||
def test_max_length_warning_if_different(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
batch_size = 1
|
||||
|
||||
max_length = 20
|
||||
num_beams = 6
|
||||
num_beam_groups = 3
|
||||
num_return_sequences = num_beams * batch_size
|
||||
stopping_criteria_max_length = 18
|
||||
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
|
||||
|
||||
# Greedy
|
||||
input_ids = input_ids.expand(6, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
|
||||
input_ids,
|
||||
decoder_start_token_id=bart_model.config.decoder_start_token_id,
|
||||
bos_token_id=bart_model.config.bos_token_id,
|
||||
)
|
||||
|
||||
with self.assertWarns(UserWarning):
|
||||
bart_model.greedy_search(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
stopping_criteria=stopping_criteria,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# Sample
|
||||
with self.assertWarns(UserWarning):
|
||||
bart_model.sample(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
stopping_criteria=stopping_criteria,
|
||||
pad_token_id=bart_model.config.pad_token_id,
|
||||
eos_token_id=bart_model.config.eos_token_id,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# Beam
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
)
|
||||
with self.assertWarns(UserWarning):
|
||||
bart_model.beam_search(
|
||||
input_ids,
|
||||
num_beams=num_beams,
|
||||
stopping_criteria=stopping_criteria,
|
||||
max_length=max_length,
|
||||
beam_scorer=beam_scorer,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# Grouped beam search
|
||||
diverse_beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
num_beam_groups=num_beam_groups,
|
||||
)
|
||||
with self.assertWarns(UserWarning):
|
||||
bart_model.group_beam_search(
|
||||
input_ids,
|
||||
diverse_beam_scorer,
|
||||
stopping_criteria=stopping_criteria,
|
||||
num_beams=num_beams,
|
||||
max_length=max_length,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user