[generate] beam search -- fix output cropping (#37080)

* handle jagged beams

* better comment

* bart -- beam search tests print special tokens

* more bart test updates

* more tests!

* better comment
This commit is contained in:
Joao Gante
2025-03-28 17:57:51 +00:00
committed by GitHub
parent 257bc670fb
commit 9fd9476005
5 changed files with 74 additions and 45 deletions

View File

@@ -3931,9 +3931,14 @@ class GenerationMixin:
beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])
# Crop the static-shaped tensors to the actual size
sequences = sequences[:, :cur_len]
beam_indices = beam_indices[:, : cur_len - decoder_prompt_len]
# Crop the static-shaped tensors to the actual size.
# `beam_indices` is initialized with -1s, and is updated with the beam index of the generated token at each
# step. We can use it to detect the generated length, which may be != `cur_len` (e.g. selected beam is from a
# previous decoding iteration)
max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max()
output_length = decoder_prompt_len + max_generated_length
sequences = sequences[:, :output_length]
beam_indices = beam_indices[:, :max_generated_length]
if return_dict_in_generate:
if not output_scores:

View File

@@ -599,13 +599,15 @@ class FastIntegrationTests(unittest.TestCase):
" 2002 to prosecute genocide, crimes against humanity and war crimes."
)
EXPECTED = (
"</s>"
" The International Criminal Court (ICC) has announced that it has been announced by the International"
" Criminal court."
"</s>"
)
dct = tok(ARTICLE, return_tensors="pt")
generated_ids = hf.generate(**dct, num_beams=4)
result = tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
result = tok.batch_decode(generated_ids)[0]
assert EXPECTED == result
def test_xsum_1_1_batch_generation(self):
@@ -729,16 +731,18 @@ class FastIntegrationTests(unittest.TestCase):
truncation=True,
)
generated_ids = self.xsum_1_1_model.generate(**batch, num_beams=4)
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)
assert (
result[0]
== " The International Criminal Court (ICC) has announced that it has been announced by the International"
result = self.tok.batch_decode(generated_ids)
assert result[0] == (
"</s>"
" The International Criminal Court (ICC) has announced that it has been announced by the International"
" Criminal court."
"</s><pad><pad><pad><pad><pad>"
)
assert (
result[1]
== " An investigation into the crash that killed at least 10 people in the French capital has been"
assert result[1] == (
"</s>"
" An investigation into the crash that killed at least 10 people in the French capital has been"
" released by the French police investigating the crash."
"</s>"
)
def test_encoder_equiv(self):
@@ -939,8 +943,10 @@ class BartModelIntegrationTests(unittest.TestCase):
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
EXPECTED_SUMMARY = (
"</s>"
"California's largest power company has begun shutting off electricity to thousands of customers in the"
" state."
"</s>"
)
dct = tok.batch_encode_plus(
[PGE_ARTICLE],
@@ -962,10 +968,7 @@ class BartModelIntegrationTests(unittest.TestCase):
decoder_start_token_id=model.config.eos_token_id,
)
decoded = tok.batch_decode(
hypotheses_batch,
skip_special_tokens=True,
)
decoded = tok.batch_decode(hypotheses_batch)
self.assertEqual(EXPECTED_SUMMARY, decoded[0])
def test_xsum_config_generation_params(self):
@@ -1189,26 +1192,32 @@ class BartModelIntegrationTests(unittest.TestCase):
assert hypotheses_batch[:, 1].eq(0).all().item()
EXPECTED = [
"</s><s>"
"A French prosecutor says he is not aware of any video footage from on board the plane. Two German "
"magazines claim to have found a cell phone video showing the crash. The publications say they watched "
"the video, which was found by a source close to the investigation. All 150 on board Germanwings Flight "
"9525 were killed.",
"9525 were killed."
"</s>",
"</s><s>"
"Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court "
"jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the "
"Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a "
"move toward greater justice.",
"move toward greater justice."
"</s><pad><pad><pad><pad>",
"</s><s>"
"U.S. and its negotiating partners reached a strong framework agreement with Iran. Peter Bergen: The "
"debate that has already begun will likely result in more heat than light. He says critics have made "
"dubious assumptions and doubtful assertions. Bergen says the goal was to block Iran from building a "
"nuclear weapon.",
"nuclear weapon."
"</s><pad><pad><pad>",
"</s><s>"
"Liana Barrientos, 39, has been married 10 times, sometimes within two weeks of each other. Prosecutors "
"say the marriages were part of an immigration scam. She pleaded not guilty at State Supreme Court in the "
"Bronx on Friday. If convicted, she faces up to four years in prison.",
"Bronx on Friday. If convicted, she faces up to four years in prison."
"</s><pad><pad><pad><pad><pad>",
]
generated_summaries = tok.batch_decode(
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
)
generated_summaries = tok.batch_decode(hypotheses_batch.tolist())
assert generated_summaries == EXPECTED
@slow

View File

@@ -434,7 +434,7 @@ class BioGptModelIntegrationTest(unittest.TestCase):
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
@slow
def test_biogpt_generation(self):
def test_biogpt_generation_beam_search(self):
tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
model.to(torch_device)
@@ -448,13 +448,15 @@ class BioGptModelIntegrationTest(unittest.TestCase):
num_beams=5,
early_stopping=True,
)
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
output_str = tokenizer.decode(output_ids[0])
EXPECTED_OUTPUT_STR = (
"</s>"
"COVID-19 is a global pandemic caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2), the"
" causative agent of coronavirus disease 2019 (COVID-19), which has spread to more than 200 countries and"
" territories, including the United States (US), Canada, Australia, New Zealand, the United Kingdom (UK),"
" and the United States of America (USA), as of March 11, 2020, with more than 800,000 confirmed cases and"
" more than 800,000 deaths."
" more than 800,000 deaths. "
"</s>"
)
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)

View File

@@ -415,16 +415,20 @@ class M2M100ModelIntegrationTests(unittest.TestCase):
)
expected_en = [
"The NSA case highlights the total absence of intelligence debate",
"I think there are two levels of response from the French government.",
"</s> __en__ "
"The NSA case highlights the total absence of intelligence debate"
"</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>",
"</s> __en__ "
"I think there are two levels of response from the French government."
"</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>",
"</s> __en__ "
"When François Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S."
" Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all"
" communications in France.",
" communications in France."
"</s>",
]
generated = tokenizer.batch_decode(
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
)
generated = tokenizer.batch_decode(hypotheses_batch)
assert generated == expected_en
@require_flash_attn

View File

@@ -1475,19 +1475,27 @@ class T5ModelIntegrationTests(unittest.TestCase):
)
expected_summaries = [
"<pad> "
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
" cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
" magazine says .",
" magazine says ."
"</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>",
"<pad> "
"the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
" preliminary examination into the situation in the occupied Palestinian territory . as members of the"
" court, Palestinians may be subject to counter-charges as well .",
" court, Palestinians may be subject to counter-charges as well ."
"</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>",
"<pad> "
"the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:"
" the debate that has already begun since the announcement of the new framework will likely result in more"
" heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and"
" implement a rigorous inspection regime .",
" implement a rigorous inspection regime ."
"</s>",
"<pad> "
"prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two"
' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10'
" times, with nine of her marriages occurring between 1999 and 2002 .",
" times, with nine of her marriages occurring between 1999 and 2002 ."
"</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>",
]
use_task_specific_params(model, "summarization")
@@ -1512,11 +1520,8 @@ class T5ModelIntegrationTests(unittest.TestCase):
early_stopping=True,
)
decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.assertListEqual(
expected_summaries,
decoded,
)
decoded = tok.batch_decode(hypotheses_batch)
self.assertListEqual(expected_summaries, decoded)
@slow
def test_translation_en_to_de(self):
@@ -1526,13 +1531,13 @@ class T5ModelIntegrationTests(unittest.TestCase):
en_text = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.'
expected_translation = (
'"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.'
'<pad> "Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.</s>'
)
input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt")
input_ids = input_ids.to(torch_device)
output = model.generate(input_ids)
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
translation = tok.decode(output[0])
self.assertEqual(translation, expected_translation)
@slow
@@ -1558,13 +1563,15 @@ class T5ModelIntegrationTests(unittest.TestCase):
do_sample=False,
early_stopping=True,
)
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
translation = tok.decode(output[0])
new_truncated_translation = (
"<pad> "
"Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre "
"un "
"« portrait familial » de générations innombrables détoiles : les plus anciennes sont observées "
"sous forme "
"de points bleus."
"</s>"
)
self.assertEqual(translation, new_truncated_translation)
@@ -1575,11 +1582,13 @@ class T5ModelIntegrationTests(unittest.TestCase):
tok = self.tokenizer
use_task_specific_params(model, "translation_en_to_ro")
en_text = "Taco Bell said it plans to add 2,000 locations in the US by 2022."
expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022."
expected_translation = (
"<pad> Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022.</s>"
)
inputs = tok(model.config.prefix + en_text, return_tensors="pt").to(torch_device)
output = model.generate(**inputs)
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
translation = tok.decode(output[0])
self.assertEqual(translation, expected_translation)
@slow