@@ -1,3 +1,4 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer
|
||||
@@ -171,3 +172,13 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(ids[-2], 2)
|
||||
self.assertEqual(ids[-1], EN_CODE)
|
||||
self.assertEqual(len(ids), desired_max_length)
|
||||
|
||||
def test_mask_token(self):
|
||||
self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "ar_AR"]), [250026, 250001])
|
||||
|
||||
def test_special_tokens_unaffacted_by_save_load(self):
|
||||
tmpdirname = tempfile.mkdtemp()
|
||||
original_special_tokens = self.tokenizer.fairseq_tokens_to_ids
|
||||
self.tokenizer.save_pretrained(tmpdirname)
|
||||
new_tok = MBartTokenizer.from_pretrained(tmpdirname)
|
||||
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)
|
||||
|
||||
Reference in New Issue
Block a user