[Tentative] Moving slow tokenizer to the Trie world. (#13220)
* Moving slow tokenizer to the Trie world. * Adding more docstrings to the Trie. * Fixing doctest (incompatible wiht our format? ) * Update src/transformers/tokenization_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Adding a lot more comment into the internals of this algorithm. * Cleaner doc. * Fixing the namings. * Update src/transformers/tokenization_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * quality. * Fixing longest first match. * Small improvements to cuts + more test + canine resistant test. * Fixing fast test. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -148,6 +148,8 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
|||||||
if len(token) > 1:
|
if len(token) > 1:
|
||||||
self.unique_no_split_tokens.append(token)
|
self.unique_no_split_tokens.append(token)
|
||||||
|
|
||||||
|
self._create_trie(self.unique_no_split_tokens)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def word_delimiter_token(self) -> str:
|
def word_delimiter_token(self) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -330,6 +332,8 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
|||||||
self._additional_special_tokens.append(AddedToken(token))
|
self._additional_special_tokens.append(AddedToken(token))
|
||||||
_insert_one_token_to_ordered_list(self.unique_no_split_tokens, token)
|
_insert_one_token_to_ordered_list(self.unique_no_split_tokens, token)
|
||||||
|
|
||||||
|
self._create_trie(self.unique_no_split_tokens)
|
||||||
|
|
||||||
return len(tokens_to_add)
|
return len(tokens_to_add)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -49,6 +49,173 @@ ADDED_TOKENS_FILE = "added_tokens.json"
|
|||||||
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
|
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
|
||||||
|
|
||||||
|
|
||||||
|
class Trie:
|
||||||
|
"""
|
||||||
|
Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass
|
||||||
|
Loose reference https://en.wikipedia.org/wiki/Trie
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.data = {}
|
||||||
|
|
||||||
|
def add(self, word: str):
|
||||||
|
"""
|
||||||
|
Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.
|
||||||
|
The special key `""` is used to represent termination.
|
||||||
|
|
||||||
|
This function is idempotent, adding twice the same word will leave the trie unchanged
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> trie = Trie()
|
||||||
|
>>> trie.add("Hello 友達")
|
||||||
|
>>> trie.data
|
||||||
|
{"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}
|
||||||
|
>>> trie.add("Hello")
|
||||||
|
>>> trie.data
|
||||||
|
{"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
|
||||||
|
"""
|
||||||
|
if not word:
|
||||||
|
# Prevent empty string
|
||||||
|
return
|
||||||
|
ref = self.data
|
||||||
|
for char in word:
|
||||||
|
ref[char] = char in ref and ref[char] or {}
|
||||||
|
ref = ref[char]
|
||||||
|
ref[""] = 1
|
||||||
|
|
||||||
|
def split(self, text: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
Will look for the words added to the trie within `text`. Output is the original string splitted along the
|
||||||
|
boundaries of the words found.
|
||||||
|
|
||||||
|
This trie will match the longest possible word first !
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> trie = Trie()
|
||||||
|
>>> trie.split("[CLS] This is a extra_id_100")
|
||||||
|
["[CLS] This is a extra_id_100"]
|
||||||
|
>>> trie.add("[CLS]")
|
||||||
|
>>> trie.add("extra_id_1")
|
||||||
|
>>> trie.add("extra_id_100")
|
||||||
|
>>> trie.split("[CLS] This is a extra_id_100")
|
||||||
|
["[CLS]", " This is a ", "extra_id_100"]
|
||||||
|
"""
|
||||||
|
|
||||||
|
# indexes are counted left of the chars index.
|
||||||
|
# "hello", index 0, is left of h, index 1 is between h and e.
|
||||||
|
# index 5 is right of the "o".
|
||||||
|
|
||||||
|
# States are going to capture every possible start (indexes as above)
|
||||||
|
# as keys, and have as values, a pointer to the position in the trie
|
||||||
|
# where we're at. This is a partial match for now.
|
||||||
|
# This enables to keep track of multiple matches while we're iterating
|
||||||
|
# the string
|
||||||
|
# If the trie contains, "blowing", and "lower" and we encounter the
|
||||||
|
# string "blower", we need to split into ["b", "lower"].
|
||||||
|
# This is where we need to keep track of multiple possible starts.
|
||||||
|
states = {}
|
||||||
|
|
||||||
|
# This will contain every indices where we need
|
||||||
|
# to cut.
|
||||||
|
# We force to cut at offset 0 and len(text) (added later)
|
||||||
|
offsets = [0]
|
||||||
|
|
||||||
|
# This is used by the lookahead which needs to skip over
|
||||||
|
# some text where the full match exceeded the place in the initial
|
||||||
|
# for loop
|
||||||
|
skip = None
|
||||||
|
# Main loop, Giving this algorithm O(n) complexity
|
||||||
|
for current, current_char in enumerate(text):
|
||||||
|
if skip and current < skip:
|
||||||
|
# Prevents the lookahead for matching twice
|
||||||
|
# like extra_id_100 and id_100
|
||||||
|
continue
|
||||||
|
|
||||||
|
# This will track every state
|
||||||
|
# that stop matching, we need to stop tracking them.
|
||||||
|
# If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then
|
||||||
|
# fail on "b", we need to remove 0 from the valid states.
|
||||||
|
to_remove = set()
|
||||||
|
# Whenever we found a match, we need to drop everything
|
||||||
|
# this is a greedy algorithm, it will match on the first found token
|
||||||
|
reset = False
|
||||||
|
|
||||||
|
# In this case, we already have partial matches (But unfinished)
|
||||||
|
for start, trie_pointer in states.items():
|
||||||
|
if current_char in trie_pointer:
|
||||||
|
# The current character being looked at has a match within the trie
|
||||||
|
# update the pointer (it will be stored back into states later).
|
||||||
|
trie_pointer = trie_pointer[current_char]
|
||||||
|
if "" in trie_pointer:
|
||||||
|
# This is a final match, we need to reset and
|
||||||
|
# store the results in `offsets`.
|
||||||
|
|
||||||
|
# Lookahead to match longest first
|
||||||
|
# Important in case of extra_id_1 vs extra_id_100
|
||||||
|
lookahead_index = current + 1
|
||||||
|
end = current + 1
|
||||||
|
next_char = text[lookahead_index] if lookahead_index < len(text) else None
|
||||||
|
while next_char in trie_pointer:
|
||||||
|
trie_pointer = trie_pointer[next_char]
|
||||||
|
lookahead_index += 1
|
||||||
|
if "" in trie_pointer:
|
||||||
|
end = lookahead_index
|
||||||
|
skip = lookahead_index
|
||||||
|
|
||||||
|
if lookahead_index == len(text):
|
||||||
|
# End of string
|
||||||
|
break
|
||||||
|
next_char = text[lookahead_index]
|
||||||
|
# End lookahead
|
||||||
|
|
||||||
|
# Storing and resetting
|
||||||
|
offsets.append(start)
|
||||||
|
offsets.append(end)
|
||||||
|
reset = True
|
||||||
|
|
||||||
|
# Storing back the new pointer into the states.
|
||||||
|
# Partial matches got longer by one.
|
||||||
|
states[start] = trie_pointer
|
||||||
|
else:
|
||||||
|
# The new character has not match in the trie, we need
|
||||||
|
# to stop keeping track of this partial match.
|
||||||
|
# We can't do it directly within the loop because of how
|
||||||
|
# python iteration works
|
||||||
|
to_remove.add(start)
|
||||||
|
|
||||||
|
# Either clearing the full start (we found a real match)
|
||||||
|
# Or clearing only the partial matches that didn't work.
|
||||||
|
if reset:
|
||||||
|
states = {}
|
||||||
|
else:
|
||||||
|
for start in to_remove:
|
||||||
|
del states[start]
|
||||||
|
|
||||||
|
# If this character is a starting character within the trie
|
||||||
|
# start keeping track of this partial match.
|
||||||
|
if current_char in self.data:
|
||||||
|
states[current] = self.data[current_char]
|
||||||
|
|
||||||
|
# We have all the offsets now, we just need to do the actual splitting.
|
||||||
|
# We need to eventually add the first part of the string and the eventual
|
||||||
|
# last part.
|
||||||
|
offsets.append(len(text))
|
||||||
|
tokens = []
|
||||||
|
start = 0
|
||||||
|
for end in offsets:
|
||||||
|
if start == end:
|
||||||
|
# This might happen if there's a match at index 0
|
||||||
|
# we're also preventing zero-width cuts in case of two
|
||||||
|
# consecutive matches
|
||||||
|
continue
|
||||||
|
tokens.append(text[start:end])
|
||||||
|
start = end
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
def _is_whitespace(char):
|
def _is_whitespace(char):
|
||||||
"""Checks whether `char` is a whitespace character."""
|
"""Checks whether `char` is a whitespace character."""
|
||||||
# \t, \n, and \r are technically control characters but we treat them
|
# \t, \n, and \r are technically control characters but we treat them
|
||||||
@@ -135,6 +302,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
|||||||
self.added_tokens_encoder: Dict[str, int] = {}
|
self.added_tokens_encoder: Dict[str, int] = {}
|
||||||
self.added_tokens_decoder: Dict[int, str] = {}
|
self.added_tokens_decoder: Dict[int, str] = {}
|
||||||
self.unique_no_split_tokens: List[str] = []
|
self.unique_no_split_tokens: List[str] = []
|
||||||
|
self.tokens_trie = Trie()
|
||||||
|
|
||||||
self._decode_use_source_tokenizer = False
|
self._decode_use_source_tokenizer = False
|
||||||
|
|
||||||
@@ -223,9 +391,19 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
|||||||
_insert_one_token_to_ordered_list(self.unique_no_split_tokens, tokens_to_add[0])
|
_insert_one_token_to_ordered_list(self.unique_no_split_tokens, tokens_to_add[0])
|
||||||
else:
|
else:
|
||||||
self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
|
self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
|
||||||
|
self._create_trie(self.unique_no_split_tokens)
|
||||||
|
|
||||||
return len(tokens_to_add)
|
return len(tokens_to_add)
|
||||||
|
|
||||||
|
def _create_trie(self, unique_no_split_tokens):
|
||||||
|
trie = Trie()
|
||||||
|
for token in unique_no_split_tokens:
|
||||||
|
if hasattr(self, "do_lower_case") and self.do_lower_case and token not in self.all_special_tokens:
|
||||||
|
trie.add(token.lower())
|
||||||
|
else:
|
||||||
|
trie.add(token)
|
||||||
|
self.tokens_trie = trie
|
||||||
|
|
||||||
def num_special_tokens_to_add(self, pair: bool = False) -> int:
|
def num_special_tokens_to_add(self, pair: bool = False) -> int:
|
||||||
"""
|
"""
|
||||||
Returns the number of added tokens when encoding a sequence with special tokens.
|
Returns the number of added tokens when encoding a sequence with special tokens.
|
||||||
@@ -279,87 +457,39 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
|||||||
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
|
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
|
||||||
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
|
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
|
||||||
|
|
||||||
def split_on_token(tok, text):
|
no_split_token = set(self.unique_no_split_tokens)
|
||||||
result = []
|
tokens = self.tokens_trie.split(text)
|
||||||
tok_extended = all_special_tokens_extended.get(tok, None)
|
# ["This is something", "<special_token_1>", " else"]
|
||||||
split_text = text.split(tok)
|
for i, token in enumerate(tokens):
|
||||||
full_word = ""
|
if token in no_split_token:
|
||||||
for i, sub_text in enumerate(split_text):
|
tok_extended = all_special_tokens_extended.get(token, None)
|
||||||
# AddedToken can control whitespace stripping around them.
|
left = tokens[i - 1] if i > 0 else None
|
||||||
# We use them for GPT2 and Roberta to have different behavior depending on the special token
|
right = tokens[i + 1] if i < len(tokens) - 1 else None
|
||||||
# Cf. https://github.com/huggingface/transformers/pull/2778
|
|
||||||
# and https://github.com/huggingface/transformers/issues/3788
|
|
||||||
if isinstance(tok_extended, AddedToken):
|
if isinstance(tok_extended, AddedToken):
|
||||||
if tok_extended.single_word:
|
if tok_extended.rstrip and right:
|
||||||
# Try to avoid splitting on token
|
|
||||||
if (
|
|
||||||
i < len(split_text) - 1
|
|
||||||
and not _is_end_of_word(sub_text)
|
|
||||||
and not _is_start_of_word(split_text[i + 1])
|
|
||||||
):
|
|
||||||
# Don't extract the special token
|
|
||||||
full_word += sub_text + tok
|
|
||||||
elif full_word:
|
|
||||||
full_word += sub_text
|
|
||||||
result.append(full_word)
|
|
||||||
full_word = ""
|
|
||||||
continue
|
|
||||||
# Strip white spaces on the right
|
|
||||||
if tok_extended.rstrip and i > 0:
|
|
||||||
# A bit counter-intuitive but we strip the left of the string
|
# A bit counter-intuitive but we strip the left of the string
|
||||||
# since tok_extended.rstrip means the special token is eating all white spaces on its right
|
# since tok_extended.rstrip means the special token is eating all white spaces on its right
|
||||||
sub_text = sub_text.lstrip()
|
tokens[i + 1] = right.lstrip()
|
||||||
# Strip white spaces on the left
|
# Strip white spaces on the left
|
||||||
if tok_extended.lstrip and i < len(split_text) - 1:
|
if tok_extended.lstrip and left:
|
||||||
sub_text = sub_text.rstrip() # Opposite here
|
tokens[i - 1] = left.rstrip() # Opposite here
|
||||||
else:
|
else:
|
||||||
# We strip left and right by default
|
# We strip left and right by default
|
||||||
if i < len(split_text) - 1:
|
if right:
|
||||||
sub_text = sub_text.rstrip()
|
tokens[i + 1] = right.lstrip()
|
||||||
if i > 0:
|
if left:
|
||||||
sub_text = sub_text.lstrip()
|
tokens[i - 1] = left.rstrip()
|
||||||
|
# ["This is something", "<special_token_1>", "else"]
|
||||||
if i == 0 and not sub_text:
|
tokenized_text = []
|
||||||
result.append(tok)
|
for token in tokens:
|
||||||
elif i == len(split_text) - 1:
|
# Need to skip eventual empty (fully stripped) tokens
|
||||||
if sub_text:
|
if not token:
|
||||||
result.append(sub_text)
|
continue
|
||||||
else:
|
if token in no_split_token:
|
||||||
pass
|
tokenized_text.append(token)
|
||||||
else:
|
else:
|
||||||
if sub_text:
|
tokenized_text.extend(self._tokenize(token))
|
||||||
result.append(sub_text)
|
# ["This", " is", " something", "<special_token_1>", "else"]
|
||||||
result.append(tok)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def split_on_tokens(tok_list, text):
|
|
||||||
if not text.strip():
|
|
||||||
return []
|
|
||||||
if not tok_list:
|
|
||||||
return self._tokenize(text)
|
|
||||||
|
|
||||||
tokenized_text = []
|
|
||||||
text_list = [text]
|
|
||||||
for tok in tok_list:
|
|
||||||
tokenized_text = []
|
|
||||||
for sub_text in text_list:
|
|
||||||
if sub_text not in self.unique_no_split_tokens:
|
|
||||||
tokenized_text.extend(split_on_token(tok, sub_text))
|
|
||||||
else:
|
|
||||||
tokenized_text.append(sub_text)
|
|
||||||
text_list = tokenized_text
|
|
||||||
|
|
||||||
return list(
|
|
||||||
itertools.chain.from_iterable(
|
|
||||||
(
|
|
||||||
self._tokenize(token) if token not in self.unique_no_split_tokens else [token]
|
|
||||||
for token in tokenized_text
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
no_split_token = self.unique_no_split_tokens
|
|
||||||
tokenized_text = split_on_tokens(no_split_token, text)
|
|
||||||
return tokenized_text
|
return tokenized_text
|
||||||
|
|
||||||
def _tokenize(self, text, **kwargs):
|
def _tokenize(self, text, **kwargs):
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ from transformers.testing_utils import (
|
|||||||
require_torch,
|
require_torch,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
from transformers.tokenization_utils import AddedToken
|
from transformers.tokenization_utils import AddedToken, Trie
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -1659,6 +1659,34 @@ class TokenizerTesterMixin:
|
|||||||
encoded_sequences_batch_padded_2[key],
|
encoded_sequences_batch_padded_2[key],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_tokenizers
|
||||||
|
def test_added_token_are_matched_longest_first(self):
|
||||||
|
if not self.test_slow_tokenizer:
|
||||||
|
self.skipTest("This test is only for slow tokenizers")
|
||||||
|
return
|
||||||
|
tokenizers = self.get_tokenizers(fast=False)
|
||||||
|
for tokenizer in tokenizers:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||||
|
try:
|
||||||
|
tokenizer.add_tokens([AddedToken("extra_id_1")])
|
||||||
|
tokenizer.add_tokens([AddedToken("extra_id_100")])
|
||||||
|
except Exception:
|
||||||
|
# Canine cannot add tokens which are not codepoints
|
||||||
|
self.skipTest("Cannot add those Added tokens")
|
||||||
|
|
||||||
|
# XXX: This used to split on `extra_id_1` first we're matching
|
||||||
|
# longest first now.
|
||||||
|
tokens = tokenizer.tokenize("This is some extra_id_100")
|
||||||
|
self.assertIn("extra_id_100", tokens)
|
||||||
|
|
||||||
|
for tokenizer in tokenizers:
|
||||||
|
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||||
|
tokenizer.add_tokens([AddedToken("extra_id_100")])
|
||||||
|
tokenizer.add_tokens([AddedToken("extra_id_1")])
|
||||||
|
|
||||||
|
tokens = tokenizer.tokenize("This is some extra_id_100")
|
||||||
|
self.assertIn("extra_id_100", tokens)
|
||||||
|
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
def test_added_token_serializable(self):
|
def test_added_token_serializable(self):
|
||||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||||
@@ -3489,3 +3517,21 @@ class TokenizerPushToHubTester(unittest.TestCase):
|
|||||||
|
|
||||||
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
|
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
|
||||||
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
||||||
|
|
||||||
|
|
||||||
|
class TrieTest(unittest.TestCase):
|
||||||
|
def test_trie(self):
|
||||||
|
trie = Trie()
|
||||||
|
trie.add("Hello 友達")
|
||||||
|
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}})
|
||||||
|
trie.add("Hello")
|
||||||
|
trie.data
|
||||||
|
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}})
|
||||||
|
|
||||||
|
def test_trie_split(self):
|
||||||
|
trie = Trie()
|
||||||
|
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])
|
||||||
|
trie.add("[CLS]")
|
||||||
|
trie.add("extra_id_1")
|
||||||
|
trie.add("extra_id_100")
|
||||||
|
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
|
||||||
|
|||||||
Reference in New Issue
Block a user