[Tentative] Moving slow tokenizer to the Trie world. (#13220)

* Moving slow tokenizer to the Trie world.

* Adding more docstrings to the Trie.

* Fixing doctest (incompatible with our format?)

* Update src/transformers/tokenization_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Adding a lot more comments on the internals of this algorithm.

* Cleaner doc.

* Fixing the namings.

* Update src/transformers/tokenization_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* quality.

* Fixing longest first match.

* Small improvements to cuts + more test + canine resistant test.

* Fixing fast test.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Nicolas Patry
2021-09-09 17:26:16 +02:00
committed by GitHub
parent b8385d8a11
commit 3dd538c4d3
3 changed files with 256 additions and 76 deletions

View File

@@ -55,7 +55,7 @@ from transformers.testing_utils import (
require_torch,
slow,
)
from transformers.tokenization_utils import AddedToken
from transformers.tokenization_utils import AddedToken, Trie
if is_torch_available():
@@ -1659,6 +1659,34 @@ class TokenizerTesterMixin:
encoded_sequences_batch_padded_2[key],
)
@require_tokenizers
def test_added_token_are_matched_longest_first(self):
    """Check that slow tokenizers match the longest added token first.

    ``extra_id_1`` is a prefix of ``extra_id_100``. Whatever the order in
    which the two are added, tokenizing a text that contains
    ``extra_id_100`` must yield ``extra_id_100`` as a single token.
    """
    if not self.test_slow_tokenizer:
        # skipTest() raises unittest.SkipTest, so no explicit return is needed.
        self.skipTest("This test is only for slow tokenizers")

    tokenizers = self.get_tokenizers(fast=False)

    # Shorter token registered first.
    for tokenizer in tokenizers:
        with self.subTest(tokenizer.__class__.__name__):
            try:
                tokenizer.add_tokens([AddedToken("extra_id_1")])
                tokenizer.add_tokens([AddedToken("extra_id_100")])
            except Exception:
                # Canine cannot add tokens which are not codepoints
                self.skipTest("Cannot add those Added tokens")

            # This used to split on `extra_id_1` first; added tokens are
            # now matched longest first.
            tokens = tokenizer.tokenize("This is some extra_id_100")
            self.assertIn("extra_id_100", tokens)

    # Longer token registered first.
    # NOTE(review): this loop reuses the tokenizer instances from the loop
    # above, so both tokens have presumably been registered already -
    # confirm that the reverse insertion order is really exercised here.
    for tokenizer in tokenizers:
        with self.subTest(tokenizer.__class__.__name__):
            tokenizer.add_tokens([AddedToken("extra_id_100")])
            tokenizer.add_tokens([AddedToken("extra_id_1")])

            tokens = tokenizer.tokenize("This is some extra_id_100")
            self.assertIn("extra_id_100", tokens)
@require_tokenizers
def test_added_token_serializable(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
@@ -3489,3 +3517,21 @@ class TokenizerPushToHubTester(unittest.TestCase):
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
class TrieTest(unittest.TestCase):
    """Unit tests for ``Trie`` construction (``add``) and ``split``."""

    def test_trie(self):
        """Added words are stored as nested per-character dicts.

        Each character of a word is one nesting level, and the empty-string
        key (mapped to ``1``) marks the end of a word.
        """
        trie = Trie()
        trie.add("Hello 友達")
        self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}})
        # "Hello" is a prefix of the word above: it only adds a terminator
        # under "o" and leaves the longer branch intact.
        trie.add("Hello")
        self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}})

    def test_trie_split(self):
        """``split`` cuts the text around added words, longest match first."""
        trie = Trie()
        # With nothing added, the text comes back as a single chunk.
        self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])
        trie.add("[CLS]")
        trie.add("extra_id_1")
        trie.add("extra_id_100")
        # "extra_id_100" must win over its prefix "extra_id_1".
        self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])