Fix pegasus-xsum integration test (#6726)
This commit is contained in:
@@ -20,8 +20,8 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
|||||||
checkpoint_name = "google/pegasus-xsum"
|
checkpoint_name = "google/pegasus-xsum"
|
||||||
src_text = [PGE_ARTICLE, XSUM_ENTRY_LONGER]
|
src_text = [PGE_ARTICLE, XSUM_ENTRY_LONGER]
|
||||||
tgt_text = [
|
tgt_text = [
|
||||||
"California's largest electricity provider has turned off power to tens of thousands of customers.",
|
"California's largest electricity provider has turned off power to hundreds of thousands of customers.",
|
||||||
"N-Dubz have revealed they weren't expecting to get four nominations at this year's Mobo Awards.",
|
"N-Dubz have said they were surprised to get four nominations for this year's Mobo Awards.",
|
||||||
]
|
]
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@@ -37,7 +37,7 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
|||||||
assert inputs.input_ids.shape == (2, 421)
|
assert inputs.input_ids.shape == (2, 421)
|
||||||
translated_tokens = self.model.generate(**inputs)
|
translated_tokens = self.model.generate(**inputs)
|
||||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||||
self.assertEqual(self.tgt_text, decoded)
|
assert self.tgt_text == decoded
|
||||||
|
|
||||||
if "cuda" not in torch_device:
|
if "cuda" not in torch_device:
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user