[generate] beam search -- fix output cropping (#37080)
* handle jagged beams
* better comment
* bart -- beam search tests print special tokens
* more bart test updates
* more tests!
* better comment
This commit is contained in:
@@ -599,13 +599,15 @@ class FastIntegrationTests(unittest.TestCase):
|
||||
" 2002 to prosecute genocide, crimes against humanity and war crimes."
|
||||
)
|
||||
EXPECTED = (
|
||||
"</s>"
|
||||
" The International Criminal Court (ICC) has announced that it has been announced by the International"
|
||||
" Criminal court."
|
||||
"</s>"
|
||||
)
|
||||
|
||||
dct = tok(ARTICLE, return_tensors="pt")
|
||||
generated_ids = hf.generate(**dct, num_beams=4)
|
||||
result = tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
result = tok.batch_decode(generated_ids)[0]
|
||||
assert EXPECTED == result
|
||||
|
||||
def test_xsum_1_1_batch_generation(self):
|
||||
@@ -729,16 +731,18 @@ class FastIntegrationTests(unittest.TestCase):
|
||||
truncation=True,
|
||||
)
|
||||
generated_ids = self.xsum_1_1_model.generate(**batch, num_beams=4)
|
||||
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
assert (
|
||||
result[0]
|
||||
== " The International Criminal Court (ICC) has announced that it has been announced by the International"
|
||||
result = self.tok.batch_decode(generated_ids)
|
||||
assert result[0] == (
|
||||
"</s>"
|
||||
" The International Criminal Court (ICC) has announced that it has been announced by the International"
|
||||
" Criminal court."
|
||||
"</s><pad><pad><pad><pad><pad>"
|
||||
)
|
||||
assert (
|
||||
result[1]
|
||||
== " An investigation into the crash that killed at least 10 people in the French capital has been"
|
||||
assert result[1] == (
|
||||
"</s>"
|
||||
" An investigation into the crash that killed at least 10 people in the French capital has been"
|
||||
" released by the French police investigating the crash."
|
||||
"</s>"
|
||||
)
|
||||
|
||||
def test_encoder_equiv(self):
|
||||
@@ -939,8 +943,10 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
|
||||
|
||||
EXPECTED_SUMMARY = (
|
||||
"</s>"
|
||||
"California's largest power company has begun shutting off electricity to thousands of customers in the"
|
||||
" state."
|
||||
"</s>"
|
||||
)
|
||||
dct = tok.batch_encode_plus(
|
||||
[PGE_ARTICLE],
|
||||
@@ -962,10 +968,7 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
decoder_start_token_id=model.config.eos_token_id,
|
||||
)
|
||||
|
||||
decoded = tok.batch_decode(
|
||||
hypotheses_batch,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
decoded = tok.batch_decode(hypotheses_batch)
|
||||
self.assertEqual(EXPECTED_SUMMARY, decoded[0])
|
||||
|
||||
def test_xsum_config_generation_params(self):
|
||||
@@ -1189,26 +1192,32 @@ class BartModelIntegrationTests(unittest.TestCase):
|
||||
assert hypotheses_batch[:, 1].eq(0).all().item()
|
||||
|
||||
EXPECTED = [
|
||||
"</s><s>"
|
||||
"A French prosecutor says he is not aware of any video footage from on board the plane. Two German "
|
||||
"magazines claim to have found a cell phone video showing the crash. The publications say they watched "
|
||||
"the video, which was found by a source close to the investigation. All 150 on board Germanwings Flight "
|
||||
"9525 were killed.",
|
||||
"9525 were killed."
|
||||
"</s>",
|
||||
"</s><s>"
|
||||
"Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court "
|
||||
"jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the "
|
||||
"Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a "
|
||||
"move toward greater justice.",
|
||||
"move toward greater justice."
|
||||
"</s><pad><pad><pad><pad>",
|
||||
"</s><s>"
|
||||
"U.S. and its negotiating partners reached a strong framework agreement with Iran. Peter Bergen: The "
|
||||
"debate that has already begun will likely result in more heat than light. He says critics have made "
|
||||
"dubious assumptions and doubtful assertions. Bergen says the goal was to block Iran from building a "
|
||||
"nuclear weapon.",
|
||||
"nuclear weapon."
|
||||
"</s><pad><pad><pad>",
|
||||
"</s><s>"
|
||||
"Liana Barrientos, 39, has been married 10 times, sometimes within two weeks of each other. Prosecutors "
|
||||
"say the marriages were part of an immigration scam. She pleaded not guilty at State Supreme Court in the "
|
||||
"Bronx on Friday. If convicted, she faces up to four years in prison.",
|
||||
"Bronx on Friday. If convicted, she faces up to four years in prison."
|
||||
"</s><pad><pad><pad><pad><pad>",
|
||||
]
|
||||
|
||||
generated_summaries = tok.batch_decode(
|
||||
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
|
||||
)
|
||||
generated_summaries = tok.batch_decode(hypotheses_batch.tolist())
|
||||
assert generated_summaries == EXPECTED
|
||||
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user