diff --git a/src/transformers/tokenization_albert.py b/src/transformers/tokenization_albert.py index 985f82c6fd..224636c997 100644 --- a/src/transformers/tokenization_albert.py +++ b/src/transformers/tokenization_albert.py @@ -114,6 +114,11 @@ class AlbertTokenizer(PreTrainedTokenizer): def vocab_size(self): return len(self.sp_model) + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + def __getstate__(self): state = self.__dict__.copy() state["sp_model"] = None diff --git a/src/transformers/tokenization_bert.py b/src/transformers/tokenization_bert.py index c76523d318..de74ee579c 100644 --- a/src/transformers/tokenization_bert.py +++ b/src/transformers/tokenization_bert.py @@ -195,6 +195,9 @@ class BertTokenizer(PreTrainedTokenizer): def vocab_size(self): return len(self.vocab) + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + def _tokenize(self, text): split_tokens = [] if self.do_basic_tokenize: diff --git a/src/transformers/tokenization_ctrl.py b/src/transformers/tokenization_ctrl.py index 1f2184f0a1..691824b92b 100644 --- a/src/transformers/tokenization_ctrl.py +++ b/src/transformers/tokenization_ctrl.py @@ -147,6 +147,9 @@ class CTRLTokenizer(PreTrainedTokenizer): def vocab_size(self): return len(self.encoder) + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + def bpe(self, token): if token in self.cache: return self.cache[token] diff --git a/src/transformers/tokenization_gpt2.py b/src/transformers/tokenization_gpt2.py index 5e8d9c7728..961797b97a 100644 --- a/src/transformers/tokenization_gpt2.py +++ b/src/transformers/tokenization_gpt2.py @@ -149,6 +149,9 @@ class GPT2Tokenizer(PreTrainedTokenizer): def vocab_size(self): return len(self.encoder) + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + def bpe(self, token): if token in self.cache: return self.cache[token] diff --git a/src/transformers/tokenization_openai.py b/src/transformers/tokenization_openai.py index e39c8df718..912ab852a7 100644 --- a/src/transformers/tokenization_openai.py +++ b/src/transformers/tokenization_openai.py @@ -125,6 +125,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer): def vocab_size(self): return len(self.encoder) + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + def bpe(self, token): word = tuple(token[:-1]) + (token[-1] + "",) if token in self.cache: diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py index 2196cc82e7..1aa5df38ad 100644 --- a/src/transformers/tokenization_t5.py +++ b/src/transformers/tokenization_t5.py @@ -119,6 +119,11 @@ class T5Tokenizer(PreTrainedTokenizer): def vocab_size(self): return self.sp_model.get_piece_size() + self._extra_ids + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + def __getstate__(self): state = self.__dict__.copy() state["sp_model"] = None diff --git a/src/transformers/tokenization_transfo_xl.py b/src/transformers/tokenization_transfo_xl.py index 66c93f1e19..f3f7ff3f31 100644 --- a/src/transformers/tokenization_transfo_xl.py +++ b/src/transformers/tokenization_transfo_xl.py @@ -273,6 +273,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer): def vocab_size(self): return len(self.idx2sym) + def get_vocab(self): + return dict(self.sym2idx, **self.added_tokens_encoder) + def _tokenize(self, line, add_eos=False, add_double_eos=False): line = line.strip() # convert to lower case diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 879d1614b3..1f6d70c649 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -286,6 +286,10 @@ class PreTrainedTokenizer(object): """ Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """ return self.convert_tokens_to_ids(self.additional_special_tokens) + def get_vocab(self): + """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """ + raise NotImplementedError() + def __init__(self, max_len=None, **kwargs): self._bos_token = None self._eos_token = None diff --git a/src/transformers/tokenization_xlm.py b/src/transformers/tokenization_xlm.py index 518f3dd7ff..93b8092abc 100644 --- a/src/transformers/tokenization_xlm.py +++ b/src/transformers/tokenization_xlm.py @@ -662,6 +662,9 @@ class XLMTokenizer(PreTrainedTokenizer): def vocab_size(self): return len(self.encoder) + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + def bpe(self, token): word = tuple(token[:-1]) + (token[-1] + "",) if token in self.cache: diff --git a/src/transformers/tokenization_xlm_roberta.py b/src/transformers/tokenization_xlm_roberta.py index ea39d945ae..1e903d8a2b 100644 --- a/src/transformers/tokenization_xlm_roberta.py +++ b/src/transformers/tokenization_xlm_roberta.py @@ -190,6 +190,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer): def vocab_size(self): return len(self.sp_model) + len(self.fairseq_tokens_to_ids) + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + def _tokenize(self, text): return self.sp_model.EncodeAsPieces(text) diff --git a/src/transformers/tokenization_xlnet.py b/src/transformers/tokenization_xlnet.py index e3ebc71072..30fdfda22e 100644 --- a/src/transformers/tokenization_xlnet.py +++ b/src/transformers/tokenization_xlnet.py @@ -114,6 +114,11 @@ class XLNetTokenizer(PreTrainedTokenizer): def vocab_size(self): return len(self.sp_model) + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + def __getstate__(self): state = self.__dict__.copy() state["sp_model"] = None diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 27be1c9b84..a597d90f04 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -542,3 +542,23 @@ class TokenizerTesterMixin: print(new_tokenizer.init_kwargs) assert tokenizer.init_kwargs["random_argument"] is True assert new_tokenizer.init_kwargs["random_argument"] is False + + def test_get_vocab(self): + tokenizer = self.get_tokenizer() + vocab = tokenizer.get_vocab() + + self.assertIsInstance(vocab, dict) + self.assertEqual(len(vocab), len(tokenizer)) + + for word, ind in vocab.items(): + self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind) + self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word) + + tokenizer.add_tokens(["asdfasdfasdfasdf"]) + vocab = tokenizer.get_vocab() + self.assertIsInstance(vocab, dict) + self.assertEqual(len(vocab), len(tokenizer)) + + for word, ind in vocab.items(): + self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind) + self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)