Exclude the madeup words from M2M100Tokenizer.vocab_size (#20976)
This commit is contained in:
@@ -193,7 +193,7 @@ class M2M100Tokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def vocab_size(self) -> int:
|
def vocab_size(self) -> int:
|
||||||
return len(self.encoder) + len(self.lang_token_to_id) + self.num_madeup_words
|
return len(self.encoder) + len(self.lang_token_to_id)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def src_lang(self) -> str:
|
def src_lang(self) -> str:
|
||||||
|
|||||||
@@ -84,15 +84,13 @@ class M2M100TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
|
self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
|
||||||
|
|
||||||
def test_get_vocab(self):
|
def test_get_vocab(self):
|
||||||
vocab_keys = list(self.get_tokenizer().get_vocab().keys())
|
tokenizer = self.get_tokenizer()
|
||||||
|
vocab_keys = list(tokenizer.get_vocab().keys())
|
||||||
|
|
||||||
self.assertEqual(vocab_keys[0], "</s>")
|
self.assertEqual(vocab_keys[0], "</s>")
|
||||||
self.assertEqual(vocab_keys[1], "<unk>")
|
self.assertEqual(vocab_keys[1], "<unk>")
|
||||||
self.assertEqual(vocab_keys[-1], "<s>")
|
self.assertEqual(vocab_keys[-1], "<s>")
|
||||||
self.assertEqual(len(vocab_keys), 110)
|
self.assertEqual(len(vocab_keys), tokenizer.vocab_size + len(tokenizer.get_added_vocab()))
|
||||||
|
|
||||||
def test_vocab_size(self):
|
|
||||||
self.assertEqual(self.get_tokenizer().vocab_size, 117)
|
|
||||||
|
|
||||||
@unittest.skip("Skip this test while all models are still to be uploaded.")
|
@unittest.skip("Skip this test while all models are still to be uploaded.")
|
||||||
def test_pretrained_model_lists(self):
|
def test_pretrained_model_lists(self):
|
||||||
@@ -161,7 +159,10 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(self.tokenizer.get_lang_id("mr"), 128063)
|
self.assertEqual(self.tokenizer.get_lang_id("mr"), 128063)
|
||||||
|
|
||||||
def test_get_vocab(self):
|
def test_get_vocab(self):
|
||||||
self.assertIn(self.tokenizer.get_lang_token("en"), self.tokenizer.get_vocab())
|
vocab = self.tokenizer.get_vocab()
|
||||||
|
self.assertEqual(len(vocab), self.tokenizer.vocab_size)
|
||||||
|
self.assertEqual(vocab["<unk>"], 3)
|
||||||
|
self.assertIn(self.tokenizer.get_lang_token("en"), vocab)
|
||||||
|
|
||||||
def test_tokenizer_batch_encode_plus(self):
|
def test_tokenizer_batch_encode_plus(self):
|
||||||
self.tokenizer.src_lang = "en"
|
self.tokenizer.src_lang = "en"
|
||||||
|
|||||||
Reference in New Issue
Block a user