Fix integration slow tests (#10670)

* PoC

* Fix slow tests for the PT1.8 Embedding problem
This commit is contained in:
Sylvain Gugger
2021-03-11 13:43:53 -05:00
committed by GitHub
parent 3ab6820370
commit fda703a553
9 changed files with 47 additions and 55 deletions

View File

@@ -343,7 +343,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
]
tgt_text = [
"Şeful ONU declară că nu există o soluţie militară în Siria",
'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţa şi mizeria pentru milioane de oameni.',
'Secretarul General Ban Ki-moon declară că răspunsul său la intensificarea sprijinului militar al Rusiei pentru Siria este că "nu există o soluţie militară" la conflictul de aproape cinci ani şi că noi arme nu vor face decât să înrăutăţească violenţa şi mizeria a milioane de oameni.',
]
expected_src_tokens = [8274, 127873, 25916, 7, 8622, 2071, 438, 67485, 53, 187895, 23, 51712, 2, 250004]
@@ -359,7 +359,9 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
@slow
def test_enro_generate_batch(self):
batch: BatchEncoding = self.tokenizer(self.src_text, return_tensors="pt").to(torch_device)
batch: BatchEncoding = self.tokenizer(self.src_text, return_tensors="pt", padding=True, truncation=True).to(
torch_device
)
translated_tokens = self.model.generate(**batch)
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
assert self.tgt_text == decoded