From ca905ba28e027bbd27543a43fc7ff7f91b9ce9c9 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Wed, 8 Feb 2023 15:19:06 +0100 Subject: [PATCH] Exclude the madeup words from M2M100Tokenizer.vocab_size (#20976) --- .../models/m2m_100/tokenization_m2m_100.py | 2 +- tests/models/m2m_100/test_tokenization_m2m_100.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/m2m_100/tokenization_m2m_100.py b/src/transformers/models/m2m_100/tokenization_m2m_100.py index 984d05cd58..dcfa51555f 100644 --- a/src/transformers/models/m2m_100/tokenization_m2m_100.py +++ b/src/transformers/models/m2m_100/tokenization_m2m_100.py @@ -193,7 +193,7 @@ class M2M100Tokenizer(PreTrainedTokenizer): @property def vocab_size(self) -> int: - return len(self.encoder) + len(self.lang_token_to_id) + self.num_madeup_words + return len(self.encoder) + len(self.lang_token_to_id) @property def src_lang(self) -> str: diff --git a/tests/models/m2m_100/test_tokenization_m2m_100.py b/tests/models/m2m_100/test_tokenization_m2m_100.py index 626ef29412..6970833541 100644 --- a/tests/models/m2m_100/test_tokenization_m2m_100.py +++ b/tests/models/m2m_100/test_tokenization_m2m_100.py @@ -84,15 +84,13 @@ class M2M100TokenizationTest(TokenizerTesterMixin, unittest.TestCase): self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token) def test_get_vocab(self): - vocab_keys = list(self.get_tokenizer().get_vocab().keys()) + tokenizer = self.get_tokenizer() + vocab_keys = list(tokenizer.get_vocab().keys()) self.assertEqual(vocab_keys[0], "") self.assertEqual(vocab_keys[1], "") self.assertEqual(vocab_keys[-1], "") - self.assertEqual(len(vocab_keys), 110) - - def test_vocab_size(self): - self.assertEqual(self.get_tokenizer().vocab_size, 117) + self.assertEqual(len(vocab_keys), tokenizer.vocab_size + len(tokenizer.get_added_vocab())) @unittest.skip("Skip this test while all models are still to be uploaded.") def test_pretrained_model_lists(self): @@ -161,7 +159,10 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase): self.assertEqual(self.tokenizer.get_lang_id("mr"), 128063) def test_get_vocab(self): - self.assertIn(self.tokenizer.get_lang_token("en"), self.tokenizer.get_vocab()) + vocab = self.tokenizer.get_vocab() + self.assertEqual(len(vocab), self.tokenizer.vocab_size) + self.assertEqual(vocab[""], 3) + self.assertIn(self.tokenizer.get_lang_token("en"), vocab) def test_tokenizer_batch_encode_plus(self): self.tokenizer.src_lang = "en"