Cleanup pytorch tests (#8033)

This commit is contained in:
Sam Shleifer
2020-10-26 08:59:06 -04:00
committed by GitHub
parent 20a0894d1a
commit 8bbe8247f1
3 changed files with 3 additions and 31 deletions

View File

@@ -58,7 +58,7 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
src_text = [PGE_ARTICLE, XSUM_ENTRY_LONGER]
tgt_text = [
"California's largest electricity provider has turned off power to hundreds of thousands of customers.",
"N-Dubz have said they were surprised to get four nominations for this year's Mobo Awards.",
"Pop group N-Dubz have revealed they were surprised to get four nominations for this year's Mobo Awards.",
]
@cached_property
@@ -72,7 +72,7 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
torch_device
)
assert inputs.input_ids.shape == (2, 421)
translated_tokens = self.model.generate(**inputs)
translated_tokens = self.model.generate(**inputs, num_beams=2)
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
assert self.tgt_text == decoded