🚨🚨 Generate: standardize beam search behavior across frameworks (#21368)

This commit is contained in:
Joao Gante
2023-02-03 10:24:02 +00:00
committed by GitHub
parent ea55bd86b9
commit f21af26279
10 changed files with 122 additions and 118 deletions

View File

@@ -2034,59 +2034,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
**model_kwargs,
)
def test_beam_search_warning_if_max_length_is_passed(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
bart_model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(
torch_device
)
batch_size = 1
num_beams = 3
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
input_ids = input_ids.expand(num_beams, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
# pretend decoder_input_ids correspond to first encoder input id
decoder_input_ids = input_ids[:, :1]
stopping_criteria_max_length = 18
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
with self.assertWarns(UserWarning):
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=torch_device,
max_length=10,
)
generated_ids = bart_model.beam_search(
decoder_input_ids,
num_beams=num_beams,
stopping_criteria=stopping_criteria,
beam_scorer=beam_scorer,
**model_kwargs,
)
beam_scorer_no_max_len = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=torch_device,
)
generated_ids_no_max_len = bart_model.beam_search(
decoder_input_ids,
num_beams=num_beams,
stopping_criteria=stopping_criteria,
beam_scorer=beam_scorer_no_max_len,
**model_kwargs,
)
# BeamSearchScorer max_length should not influence "real" max_length
self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
def test_custom_stopping_criteria_overload_error(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")

View File

@@ -426,7 +426,7 @@ class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationT
)
input_ids = tokenizer(input_str, return_tensors="np").input_ids
sequences = model.generate(input_ids, num_beams=2, max_length=20).sequences
sequences = model.generate(input_ids, num_beams=2, min_length=None, max_length=20).sequences
output_str = tokenizer.batch_decode(sequences)[0]

View File

@@ -224,7 +224,7 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittes
output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
expected_string = [
"Hello this is a long string of words. I'm going to try to explain what I mean.",
"Hello this is a long string of words. I'm going to start with the first one.\n",
"Hey, I'm not sure if I'm going to be able to do",
]

View File

@@ -1076,7 +1076,7 @@ class FlaxT5ModelIntegrationTests(unittest.TestCase):
expected_summaries = [
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
" cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
" magazine says . all 150 on board were killed when germanwings flight 9525 crashed .",
" magazine says . all 150 on board were killed in the crash .",
"the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
" preliminary examination into the situation in the occupied Palestinian territory . as members of the"
" court, Palestinians may be subject to counter-charges as well .",