[Tentative] Moving slow tokenizer to the Trie world. (#13220)
* Moving slow tokenizer to the Trie world. * Adding more docstrings to the Trie. * Fixing doctest (incompatible with our format?) * Update src/transformers/tokenization_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Adding a lot more comments into the internals of this algorithm. * Cleaner doc. * Fixing the namings. * Update src/transformers/tokenization_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * quality. * Fixing longest first match. * Small improvements to cuts + more tests + canine resistant test. * Fixing fast test. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -55,7 +55,7 @@ from transformers.testing_utils import (
|
||||
require_torch,
|
||||
slow,
|
||||
)
|
||||
from transformers.tokenization_utils import AddedToken
|
||||
from transformers.tokenization_utils import AddedToken, Trie
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -1659,6 +1659,34 @@ class TokenizerTesterMixin:
|
||||
encoded_sequences_batch_padded_2[key],
|
||||
)
|
||||
|
||||
@require_tokenizers
def test_added_token_are_matched_longest_first(self):
    """Check that slow tokenizers match the longest added token first.

    `extra_id_1` is a strict prefix of `extra_id_100`. Whatever the order in
    which the two tokens were added, tokenizing a text containing
    `extra_id_100` must yield `extra_id_100` as a single token.
    """
    if not self.test_slow_tokenizer:
        # skipTest raises unittest.SkipTest, so nothing after it would run.
        self.skipTest("This test is only for slow tokenizers")

    insertion_orders = (
        ("extra_id_1", "extra_id_100"),
        ("extra_id_100", "extra_id_1"),
    )
    for insertion_order in insertion_orders:
        # Ask for tokenizers again for every order so the second order is not
        # exercised on tokenizers that already contain both tokens.
        # NOTE(review): assumes get_tokenizers builds new instances on each
        # call -- confirm against its definition.
        for tokenizer in self.get_tokenizers(fast=False):
            with self.subTest(f"{tokenizer.__class__.__name__}", insertion_order=insertion_order):
                try:
                    for added_token in insertion_order:
                        tokenizer.add_tokens([AddedToken(added_token)])
                except Exception:
                    # Canine cannot add tokens which are not codepoints
                    self.skipTest("Cannot add those Added tokens")

                # XXX: This used to split on `extra_id_1` first; we are
                # matching longest first now.
                tokens = tokenizer.tokenize("This is some extra_id_100")
                self.assertIn("extra_id_100", tokens)
|
||||
|
||||
@require_tokenizers
|
||||
def test_added_token_serializable(self):
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
@@ -3489,3 +3517,21 @@ class TokenizerPushToHubTester(unittest.TestCase):
|
||||
|
||||
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
|
||||
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
||||
|
||||
|
||||
class TrieTest(unittest.TestCase):
    """Unit tests for `Trie` (imported from `transformers.tokenization_utils`)."""

    def test_trie(self):
        """`Trie.add` stores one nested dict level per character, ending each word with a `""` key."""
        trie = Trie()
        trie.add("Hello 友達")
        self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}})

        # Adding a prefix of an existing entry must only add a terminator at
        # the prefix's last character and leave the longer entry intact.
        trie.add("Hello")
        self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}})

    def test_trie_split(self):
        """`Trie.split` cuts the text around added words, preferring the longest match."""
        trie = Trie()
        # With no words added, the text comes back as a single chunk.
        self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])

        trie.add("[CLS]")
        trie.add("extra_id_1")
        trie.add("extra_id_100")
        # `extra_id_1` is a prefix of `extra_id_100`; the longer word must win.
        self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
|
||||
|
||||
Reference in New Issue
Block a user