Update expected values in constrained beam search tests (#17887)

* fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-06-28 08:53:53 +02:00
committed by GitHub
parent e02037b352
commit 0b0dd97737

View File

@@ -2510,8 +2510,8 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual( self.assertListEqual(
generated_text, generated_text,
[ [
"The soldiers were not prepared and didn't know how big the big weapons would be, so they scared them" "The soldiers were not prepared and didn't know what to do. They had no idea how they would react if"
" off. They had no idea what to do", " the enemy attacked them, big weapons scared"
], ],
) )
@@ -2549,8 +2549,9 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual( self.assertListEqual(
generated_text, generated_text,
[ [
"The soldiers, who were all scared and screaming at each other as they tried to get out of the", "The soldiers, who had been stationed at the base for more than a year before being evacuated"
"The child was taken to a local hospital where she screamed and scared for her life, police said.", " screaming scared",
"The child was taken to a local hospital where he died.\n 'I don't think screaming scared",
], ],
) )
@@ -2585,8 +2586,9 @@ class GenerationIntegrationTests(unittest.TestCase):
self.assertListEqual( self.assertListEqual(
generated_text, generated_text,
[ [
"The soldiers, who were all scared and screaming at each other as they tried to get out of the", "The soldiers, who had been stationed at the base for more than a year before being evacuated"
"The child was taken to a local hospital where she screamed and scared for her life, police said.", " screaming scared",
"The child was taken to a local hospital where he died.\n 'I don't think screaming scared",
], ],
) )
@@ -2612,7 +2614,7 @@ class GenerationIntegrationTests(unittest.TestCase):
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(outputs, ["Wie alter sind Sie?"]) self.assertListEqual(outputs, ["Wie alt sind Sie?"])
@slow @slow
def test_constrained_beam_search_example_integration(self): def test_constrained_beam_search_example_integration(self):
@@ -2656,7 +2658,7 @@ class GenerationIntegrationTests(unittest.TestCase):
) )
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True) outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(outputs, ["Wie alter sind Sie?"]) self.assertListEqual(outputs, ["Wie alt sind Sie?"])
def test_constrained_beam_search_mixin_type_checks(self): def test_constrained_beam_search_mixin_type_checks(self):
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random") tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")