@@ -123,6 +123,7 @@ class MBartEnroIntegrationTest(AbstractMBartIntegrationTest):
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
class MBartCC25IntegrationTest(AbstractMBartIntegrationTest):
|
||||
checkpoint_name = "facebook/mbart-large-cc25"
|
||||
src_text = [
|
||||
@@ -140,3 +141,14 @@ class MBartCC25IntegrationTest(AbstractMBartIntegrationTest):
|
||||
)
|
||||
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
|
||||
self.assertEqual(self.tgt_text[0], decoded[0])
|
||||
|
||||
@slow
|
||||
def test_fill_mask(self):
|
||||
inputs = self.tokenizer.prepare_translation_batch(["One of the best <mask> I ever read!"]).to(torch_device)
|
||||
outputs = self.model.generate(
|
||||
inputs["input_ids"], decoder_start_token_id=self.tokenizer.lang_code_to_id["en_XX"], num_beams=1
|
||||
)
|
||||
prediction: str = self.tokenizer.batch_decode(
|
||||
outputs, clean_up_tokenization_spaces=True, skip_special_tokens=True
|
||||
)[0]
|
||||
self.assertEqual(prediction, "of the best books I ever read!")
|
||||
|
||||
Reference in New Issue
Block a user