[generate] beam search -- fix output cropping (#37080)

* handle jagged beams

* better comment

* bart -- beam search tests print special tokens

* more bart test updates

* more tests!

* better comment
This commit is contained in:
Joao Gante
2025-03-28 17:57:51 +00:00
committed by GitHub
parent 257bc670fb
commit 9fd9476005
5 changed files with 74 additions and 45 deletions

View File

@@ -1475,19 +1475,27 @@ class T5ModelIntegrationTests(unittest.TestCase):
)
expected_summaries = [
"<pad> "
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
" cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
" magazine says .",
" magazine says ."
"</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>",
"<pad> "
"the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
" preliminary examination into the situation in the occupied Palestinian territory . as members of the"
" court, Palestinians may be subject to counter-charges as well .",
" court, Palestinians may be subject to counter-charges as well ."
"</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>",
"<pad> "
"the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:"
" the debate that has already begun since the announcement of the new framework will likely result in more"
" heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and"
" implement a rigorous inspection regime .",
" implement a rigorous inspection regime ."
"</s>",
"<pad> "
"prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two"
' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10'
" times, with nine of her marriages occurring between 1999 and 2002 .",
" times, with nine of her marriages occurring between 1999 and 2002 ."
"</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>",
]
use_task_specific_params(model, "summarization")
@@ -1512,11 +1520,8 @@ class T5ModelIntegrationTests(unittest.TestCase):
early_stopping=True,
)
decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.assertListEqual(
expected_summaries,
decoded,
)
decoded = tok.batch_decode(hypotheses_batch)
self.assertListEqual(expected_summaries, decoded)
@slow
def test_translation_en_to_de(self):
@@ -1526,13 +1531,13 @@ class T5ModelIntegrationTests(unittest.TestCase):
en_text = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.'
expected_translation = (
'"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.'
'<pad> "Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.</s>'
)
input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt")
input_ids = input_ids.to(torch_device)
output = model.generate(input_ids)
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
translation = tok.decode(output[0])
self.assertEqual(translation, expected_translation)
@slow
@@ -1558,13 +1563,15 @@ class T5ModelIntegrationTests(unittest.TestCase):
do_sample=False,
early_stopping=True,
)
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
translation = tok.decode(output[0])
new_truncated_translation = (
"<pad> "
"Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre "
"un "
"« portrait familial » de générations innombrables détoiles : les plus anciennes sont observées "
"sous forme "
"de points bleus."
"</s>"
)
self.assertEqual(translation, new_truncated_translation)
@@ -1575,11 +1582,13 @@ class T5ModelIntegrationTests(unittest.TestCase):
tok = self.tokenizer
use_task_specific_params(model, "translation_en_to_ro")
en_text = "Taco Bell said it plans to add 2,000 locations in the US by 2022."
expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022."
expected_translation = (
"<pad> Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022.</s>"
)
inputs = tok(model.config.prefix + en_text, return_tensors="pt").to(torch_device)
output = model.generate(**inputs)
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
translation = tok.decode(output[0])
self.assertEqual(translation, expected_translation)
@slow