From e11d923bfc61ed640bc7e696549578361126485e Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 25 Aug 2020 14:06:28 -0400 Subject: [PATCH] Fix pegasus-xsum integration test (#6726) --- tests/test_modeling_pegasus.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_pegasus.py b/tests/test_modeling_pegasus.py index ac6be27210..6fb387daa7 100644 --- a/tests/test_modeling_pegasus.py +++ b/tests/test_modeling_pegasus.py @@ -20,8 +20,8 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest): checkpoint_name = "google/pegasus-xsum" src_text = [PGE_ARTICLE, XSUM_ENTRY_LONGER] tgt_text = [ - "California's largest electricity provider has turned off power to tens of thousands of customers.", - "N-Dubz have revealed they weren't expecting to get four nominations at this year's Mobo Awards.", + "California's largest electricity provider has turned off power to hundreds of thousands of customers.", + "N-Dubz have said they were surprised to get four nominations for this year's Mobo Awards.", ] @cached_property @@ -37,7 +37,7 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest): assert inputs.input_ids.shape == (2, 421) translated_tokens = self.model.generate(**inputs) decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True) - self.assertEqual(self.tgt_text, decoded) + assert self.tgt_text == decoded if "cuda" not in torch_device: return