* Save total_vocab_size = vocab_size + user-added tokens to speed up `len(tokenizer)` * Update the saved length when added_tokens_decoder is set * Add a test for len(tokenizer)
This commit is contained in:
@@ -480,6 +480,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
|||||||
|
|
||||||
self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token
|
self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token
|
||||||
self._added_tokens_encoder[str(token)] = index
|
self._added_tokens_encoder[str(token)] = index
|
||||||
|
self._update_total_vocab_size()
|
||||||
|
|
||||||
def get_added_vocab(self) -> Dict[str, int]:
|
def get_added_vocab(self) -> Dict[str, int]:
|
||||||
"""
|
"""
|
||||||
@@ -494,10 +495,17 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
|||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""
|
"""
|
||||||
Size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because otherwise if
|
Size of the full vocabulary with the added tokens.
|
||||||
there is a hole in the vocab, we will add tokenizers at a wrong index.
|
|
||||||
"""
|
"""
|
||||||
return len(set(self.get_vocab().keys()))
|
return self.total_vocab_size
|
||||||
|
|
||||||
|
def _update_total_vocab_size(self):
|
||||||
|
"""
|
||||||
|
Update the size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because
|
||||||
|
otherwise if there is a hole in the vocab, we will add tokens at the wrong index. This operation is slow and
|
||||||
|
is only updated when adding tokens.
|
||||||
|
"""
|
||||||
|
self.total_vocab_size = len(self.get_vocab())
|
||||||
|
|
||||||
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
|
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
|
||||||
"""
|
"""
|
||||||
@@ -574,6 +582,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
|||||||
logger.info(f"Adding {token} to the vocabulary")
|
logger.info(f"Adding {token} to the vocabulary")
|
||||||
|
|
||||||
self._update_trie()
|
self._update_trie()
|
||||||
|
self._update_total_vocab_size()
|
||||||
return added_tokens
|
return added_tokens
|
||||||
|
|
||||||
def _update_trie(self, unique_no_split_tokens: Optional[str] = []):
|
def _update_trie(self, unique_no_split_tokens: Optional[str] = []):
|
||||||
|
|||||||
@@ -284,3 +284,15 @@ class TokenizerUtilsTest(unittest.TestCase):
|
|||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
bert_tokenizer.save(os.path.join(tmpdirname, "tokenizer.json"))
|
bert_tokenizer.save(os.path.join(tmpdirname, "tokenizer.json"))
|
||||||
PreTrainedTokenizerFast(tokenizer_file=os.path.join(tmpdirname, "tokenizer.json"))
|
PreTrainedTokenizerFast(tokenizer_file=os.path.join(tmpdirname, "tokenizer.json"))
|
||||||
|
|
||||||
|
def test_len_tokenizer(self):
|
||||||
|
for tokenizer_class in [BertTokenizer, BertTokenizerFast]:
|
||||||
|
with self.subTest(f"{tokenizer_class}"):
|
||||||
|
tokenizer = tokenizer_class.from_pretrained("bert-base-uncased")
|
||||||
|
added_tokens_size = len(tokenizer.added_tokens_decoder)
|
||||||
|
self.assertEqual(len(tokenizer), tokenizer.vocab_size)
|
||||||
|
|
||||||
|
tokenizer.add_tokens(["<test_token>"])
|
||||||
|
self.assertEqual(len(tokenizer), tokenizer.vocab_size + 1)
|
||||||
|
self.assertEqual(len(tokenizer.added_tokens_decoder), added_tokens_size + 1)
|
||||||
|
self.assertEqual(len(tokenizer.added_tokens_encoder), added_tokens_size + 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user