Fix #6096: MBartTokenizer's mask token (#6098)

This commit is contained in:
Sam Shleifer
2020-07-28 18:27:58 -04:00
committed by GitHub
parent b1c8b76907
commit 5abe50381a
3 changed files with 24 additions and 0 deletions

View File

@@ -1,3 +1,4 @@
import tempfile
import unittest
from transformers import AutoTokenizer, BatchEncoding, MBartTokenizer
@@ -171,3 +172,13 @@ class MBartEnroIntegrationTest(unittest.TestCase):
self.assertEqual(ids[-2], 2)
self.assertEqual(ids[-1], EN_CODE)
self.assertEqual(len(ids), desired_max_length)
def test_mask_token(self):
self.assertListEqual(self.tokenizer.convert_tokens_to_ids(["<mask>", "ar_AR"]), [250026, 250001])
def test_special_tokens_unaffacted_by_save_load(self):
tmpdirname = tempfile.mkdtemp()
original_special_tokens = self.tokenizer.fairseq_tokens_to_ids
self.tokenizer.save_pretrained(tmpdirname)
new_tok = MBartTokenizer.from_pretrained(tmpdirname)
self.assertDictEqual(new_tok.fairseq_tokens_to_ids, original_special_tokens)