From ab108a0e31660647c19287b5057dc16282b2041a Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Tue, 25 Oct 2022 15:18:24 +0200
Subject: [PATCH] Add missing lang tokens in M2M100Tokenizer.get_vocab (#18416)

---
 src/transformers/models/m2m_100/tokenization_m2m_100.py | 2 +-
 tests/models/m2m_100/test_tokenization_m2m_100.py       | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/m2m_100/tokenization_m2m_100.py b/src/transformers/models/m2m_100/tokenization_m2m_100.py
index b67b82fb7a..fff596046e 100644
--- a/src/transformers/models/m2m_100/tokenization_m2m_100.py
+++ b/src/transformers/models/m2m_100/tokenization_m2m_100.py
@@ -280,7 +280,7 @@ class M2M100Tokenizer(PreTrainedTokenizer):
         return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
 
     def get_vocab(self) -> Dict:
-        vocab = self.encoder.copy()
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
         vocab.update(self.added_tokens_encoder)
         return vocab
 
diff --git a/tests/models/m2m_100/test_tokenization_m2m_100.py b/tests/models/m2m_100/test_tokenization_m2m_100.py
index ca8349d940..f8c5f5b7ba 100644
--- a/tests/models/m2m_100/test_tokenization_m2m_100.py
+++ b/tests/models/m2m_100/test_tokenization_m2m_100.py
@@ -89,7 +89,7 @@ class M2M100TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertEqual(vocab_keys[0], "</s>")
         self.assertEqual(vocab_keys[1], "<unk>")
         self.assertEqual(vocab_keys[-1], "<s>")
-        self.assertEqual(len(vocab_keys), 10)
+        self.assertEqual(len(vocab_keys), 110)
 
     def test_vocab_size(self):
         self.assertEqual(self.get_tokenizer().vocab_size, 117)
@@ -160,6 +160,9 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
         self.assertEqual(self.tokenizer.get_lang_id("ro"), 128076)
         self.assertEqual(self.tokenizer.get_lang_id("mr"), 128063)
 
+    def test_get_vocab(self):
+        self.assertIn(self.tokenizer.get_lang_token("en"), self.tokenizer.get_vocab())
+
     def test_tokenizer_batch_encode_plus(self):
         self.tokenizer.src_lang = "en"
         ids = self.tokenizer.batch_encode_plus(self.src_text).input_ids[0]