[generate] beam search -- fix output cropping (#37080)
* handle jagged beams
* better comment
* bart -- beam search tests print special tokens
* more bart test updates
* more tests!
* better comment
This commit is contained in:
@@ -1475,19 +1475,27 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
expected_summaries = [
|
||||
"<pad> "
|
||||
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
|
||||
" cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
|
||||
" magazine says .",
|
||||
" magazine says ."
|
||||
"</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>",
|
||||
"<pad> "
|
||||
"the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
|
||||
" preliminary examination into the situation in the occupied Palestinian territory . as members of the"
|
||||
" court, Palestinians may be subject to counter-charges as well .",
|
||||
" court, Palestinians may be subject to counter-charges as well ."
|
||||
"</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>",
|
||||
"<pad> "
|
||||
"the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:"
|
||||
" the debate that has already begun since the announcement of the new framework will likely result in more"
|
||||
" heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and"
|
||||
" implement a rigorous inspection regime .",
|
||||
" implement a rigorous inspection regime ."
|
||||
"</s>",
|
||||
"<pad> "
|
||||
"prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two"
|
||||
' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10'
|
||||
" times, with nine of her marriages occurring between 1999 and 2002 .",
|
||||
" times, with nine of her marriages occurring between 1999 and 2002 ."
|
||||
"</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>",
|
||||
]
|
||||
|
||||
use_task_specific_params(model, "summarization")
|
||||
@@ -1512,11 +1520,8 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
early_stopping=True,
|
||||
)
|
||||
|
||||
decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
self.assertListEqual(
|
||||
expected_summaries,
|
||||
decoded,
|
||||
)
|
||||
decoded = tok.batch_decode(hypotheses_batch)
|
||||
self.assertListEqual(expected_summaries, decoded)
|
||||
|
||||
@slow
|
||||
def test_translation_en_to_de(self):
|
||||
@@ -1526,13 +1531,13 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
en_text = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.'
|
||||
expected_translation = (
|
||||
'"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.'
|
||||
'<pad> "Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.</s>'
|
||||
)
|
||||
|
||||
input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt")
|
||||
input_ids = input_ids.to(torch_device)
|
||||
output = model.generate(input_ids)
|
||||
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
translation = tok.decode(output[0])
|
||||
self.assertEqual(translation, expected_translation)
|
||||
|
||||
@slow
|
||||
@@ -1558,13 +1563,15 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
do_sample=False,
|
||||
early_stopping=True,
|
||||
)
|
||||
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
translation = tok.decode(output[0])
|
||||
new_truncated_translation = (
|
||||
"<pad> "
|
||||
"Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre "
|
||||
"un "
|
||||
"« portrait familial » de générations innombrables d’étoiles : les plus anciennes sont observées "
|
||||
"sous forme "
|
||||
"de points bleus."
|
||||
"</s>"
|
||||
)
|
||||
|
||||
self.assertEqual(translation, new_truncated_translation)
|
||||
@@ -1575,11 +1582,13 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
tok = self.tokenizer
|
||||
use_task_specific_params(model, "translation_en_to_ro")
|
||||
en_text = "Taco Bell said it plans to add 2,000 locations in the US by 2022."
|
||||
expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022."
|
||||
expected_translation = (
|
||||
"<pad> Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022.</s>"
|
||||
)
|
||||
|
||||
inputs = tok(model.config.prefix + en_text, return_tensors="pt").to(torch_device)
|
||||
output = model.generate(**inputs)
|
||||
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
translation = tok.decode(output[0])
|
||||
self.assertEqual(translation, expected_translation)
|
||||
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user