[Tentative] Moving slow tokenizer to the Trie world. (#13220)
* Moving slow tokenizer to the Trie world. * Adding more docstrings to the Trie. * Fixing doctest (incompatible wiht our format? ) * Update src/transformers/tokenization_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Adding a lot more comment into the internals of this algorithm. * Cleaner doc. * Fixing the namings. * Update src/transformers/tokenization_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * quality. * Fixing longest first match. * Small improvements to cuts + more test + canine resistant test. * Fixing fast test. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -148,6 +148,8 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||
if len(token) > 1:
|
||||
self.unique_no_split_tokens.append(token)
|
||||
|
||||
self._create_trie(self.unique_no_split_tokens)
|
||||
|
||||
@property
|
||||
def word_delimiter_token(self) -> str:
|
||||
"""
|
||||
@@ -330,6 +332,8 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||
self._additional_special_tokens.append(AddedToken(token))
|
||||
_insert_one_token_to_ordered_list(self.unique_no_split_tokens, token)
|
||||
|
||||
self._create_trie(self.unique_no_split_tokens)
|
||||
|
||||
return len(tokens_to_add)
|
||||
|
||||
|
||||
|
||||
@@ -49,6 +49,173 @@ ADDED_TOKENS_FILE = "added_tokens.json"
|
||||
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
|
||||
|
||||
|
||||
class Trie:
|
||||
"""
|
||||
Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass
|
||||
Loose reference https://en.wikipedia.org/wiki/Trie
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.data = {}
|
||||
|
||||
def add(self, word: str):
|
||||
"""
|
||||
Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.
|
||||
The special key `""` is used to represent termination.
|
||||
|
||||
This function is idempotent, adding twice the same word will leave the trie unchanged
|
||||
|
||||
Example::
|
||||
|
||||
>>> trie = Trie()
|
||||
>>> trie.add("Hello 友達")
|
||||
>>> trie.data
|
||||
{"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}
|
||||
>>> trie.add("Hello")
|
||||
>>> trie.data
|
||||
{"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
|
||||
"""
|
||||
if not word:
|
||||
# Prevent empty string
|
||||
return
|
||||
ref = self.data
|
||||
for char in word:
|
||||
ref[char] = char in ref and ref[char] or {}
|
||||
ref = ref[char]
|
||||
ref[""] = 1
|
||||
|
||||
def split(self, text: str) -> List[str]:
|
||||
"""
|
||||
Will look for the words added to the trie within `text`. Output is the original string splitted along the
|
||||
boundaries of the words found.
|
||||
|
||||
This trie will match the longest possible word first !
|
||||
|
||||
Example::
|
||||
|
||||
>>> trie = Trie()
|
||||
>>> trie.split("[CLS] This is a extra_id_100")
|
||||
["[CLS] This is a extra_id_100"]
|
||||
>>> trie.add("[CLS]")
|
||||
>>> trie.add("extra_id_1")
|
||||
>>> trie.add("extra_id_100")
|
||||
>>> trie.split("[CLS] This is a extra_id_100")
|
||||
["[CLS]", " This is a ", "extra_id_100"]
|
||||
"""
|
||||
|
||||
# indexes are counted left of the chars index.
|
||||
# "hello", index 0, is left of h, index 1 is between h and e.
|
||||
# index 5 is right of the "o".
|
||||
|
||||
# States are going to capture every possible start (indexes as above)
|
||||
# as keys, and have as values, a pointer to the position in the trie
|
||||
# where we're at. This is a partial match for now.
|
||||
# This enables to keep track of multiple matches while we're iterating
|
||||
# the string
|
||||
# If the trie contains, "blowing", and "lower" and we encounter the
|
||||
# string "blower", we need to split into ["b", "lower"].
|
||||
# This is where we need to keep track of multiple possible starts.
|
||||
states = {}
|
||||
|
||||
# This will contain every indices where we need
|
||||
# to cut.
|
||||
# We force to cut at offset 0 and len(text) (added later)
|
||||
offsets = [0]
|
||||
|
||||
# This is used by the lookahead which needs to skip over
|
||||
# some text where the full match exceeded the place in the initial
|
||||
# for loop
|
||||
skip = None
|
||||
# Main loop, Giving this algorithm O(n) complexity
|
||||
for current, current_char in enumerate(text):
|
||||
if skip and current < skip:
|
||||
# Prevents the lookahead for matching twice
|
||||
# like extra_id_100 and id_100
|
||||
continue
|
||||
|
||||
# This will track every state
|
||||
# that stop matching, we need to stop tracking them.
|
||||
# If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then
|
||||
# fail on "b", we need to remove 0 from the valid states.
|
||||
to_remove = set()
|
||||
# Whenever we found a match, we need to drop everything
|
||||
# this is a greedy algorithm, it will match on the first found token
|
||||
reset = False
|
||||
|
||||
# In this case, we already have partial matches (But unfinished)
|
||||
for start, trie_pointer in states.items():
|
||||
if current_char in trie_pointer:
|
||||
# The current character being looked at has a match within the trie
|
||||
# update the pointer (it will be stored back into states later).
|
||||
trie_pointer = trie_pointer[current_char]
|
||||
if "" in trie_pointer:
|
||||
# This is a final match, we need to reset and
|
||||
# store the results in `offsets`.
|
||||
|
||||
# Lookahead to match longest first
|
||||
# Important in case of extra_id_1 vs extra_id_100
|
||||
lookahead_index = current + 1
|
||||
end = current + 1
|
||||
next_char = text[lookahead_index] if lookahead_index < len(text) else None
|
||||
while next_char in trie_pointer:
|
||||
trie_pointer = trie_pointer[next_char]
|
||||
lookahead_index += 1
|
||||
if "" in trie_pointer:
|
||||
end = lookahead_index
|
||||
skip = lookahead_index
|
||||
|
||||
if lookahead_index == len(text):
|
||||
# End of string
|
||||
break
|
||||
next_char = text[lookahead_index]
|
||||
# End lookahead
|
||||
|
||||
# Storing and resetting
|
||||
offsets.append(start)
|
||||
offsets.append(end)
|
||||
reset = True
|
||||
|
||||
# Storing back the new pointer into the states.
|
||||
# Partial matches got longer by one.
|
||||
states[start] = trie_pointer
|
||||
else:
|
||||
# The new character has not match in the trie, we need
|
||||
# to stop keeping track of this partial match.
|
||||
# We can't do it directly within the loop because of how
|
||||
# python iteration works
|
||||
to_remove.add(start)
|
||||
|
||||
# Either clearing the full start (we found a real match)
|
||||
# Or clearing only the partial matches that didn't work.
|
||||
if reset:
|
||||
states = {}
|
||||
else:
|
||||
for start in to_remove:
|
||||
del states[start]
|
||||
|
||||
# If this character is a starting character within the trie
|
||||
# start keeping track of this partial match.
|
||||
if current_char in self.data:
|
||||
states[current] = self.data[current_char]
|
||||
|
||||
# We have all the offsets now, we just need to do the actual splitting.
|
||||
# We need to eventually add the first part of the string and the eventual
|
||||
# last part.
|
||||
offsets.append(len(text))
|
||||
tokens = []
|
||||
start = 0
|
||||
for end in offsets:
|
||||
if start == end:
|
||||
# This might happen if there's a match at index 0
|
||||
# we're also preventing zero-width cuts in case of two
|
||||
# consecutive matches
|
||||
continue
|
||||
tokens.append(text[start:end])
|
||||
start = end
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
def _is_whitespace(char):
|
||||
"""Checks whether `char` is a whitespace character."""
|
||||
# \t, \n, and \r are technically control characters but we treat them
|
||||
@@ -135,6 +302,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
||||
self.added_tokens_encoder: Dict[str, int] = {}
|
||||
self.added_tokens_decoder: Dict[int, str] = {}
|
||||
self.unique_no_split_tokens: List[str] = []
|
||||
self.tokens_trie = Trie()
|
||||
|
||||
self._decode_use_source_tokenizer = False
|
||||
|
||||
@@ -223,9 +391,19 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
||||
_insert_one_token_to_ordered_list(self.unique_no_split_tokens, tokens_to_add[0])
|
||||
else:
|
||||
self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
|
||||
self._create_trie(self.unique_no_split_tokens)
|
||||
|
||||
return len(tokens_to_add)
|
||||
|
||||
def _create_trie(self, unique_no_split_tokens):
|
||||
trie = Trie()
|
||||
for token in unique_no_split_tokens:
|
||||
if hasattr(self, "do_lower_case") and self.do_lower_case and token not in self.all_special_tokens:
|
||||
trie.add(token.lower())
|
||||
else:
|
||||
trie.add(token)
|
||||
self.tokens_trie = trie
|
||||
|
||||
def num_special_tokens_to_add(self, pair: bool = False) -> int:
|
||||
"""
|
||||
Returns the number of added tokens when encoding a sequence with special tokens.
|
||||
@@ -279,87 +457,39 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
||||
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
|
||||
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
|
||||
|
||||
def split_on_token(tok, text):
|
||||
result = []
|
||||
tok_extended = all_special_tokens_extended.get(tok, None)
|
||||
split_text = text.split(tok)
|
||||
full_word = ""
|
||||
for i, sub_text in enumerate(split_text):
|
||||
# AddedToken can control whitespace stripping around them.
|
||||
# We use them for GPT2 and Roberta to have different behavior depending on the special token
|
||||
# Cf. https://github.com/huggingface/transformers/pull/2778
|
||||
# and https://github.com/huggingface/transformers/issues/3788
|
||||
no_split_token = set(self.unique_no_split_tokens)
|
||||
tokens = self.tokens_trie.split(text)
|
||||
# ["This is something", "<special_token_1>", " else"]
|
||||
for i, token in enumerate(tokens):
|
||||
if token in no_split_token:
|
||||
tok_extended = all_special_tokens_extended.get(token, None)
|
||||
left = tokens[i - 1] if i > 0 else None
|
||||
right = tokens[i + 1] if i < len(tokens) - 1 else None
|
||||
if isinstance(tok_extended, AddedToken):
|
||||
if tok_extended.single_word:
|
||||
# Try to avoid splitting on token
|
||||
if (
|
||||
i < len(split_text) - 1
|
||||
and not _is_end_of_word(sub_text)
|
||||
and not _is_start_of_word(split_text[i + 1])
|
||||
):
|
||||
# Don't extract the special token
|
||||
full_word += sub_text + tok
|
||||
elif full_word:
|
||||
full_word += sub_text
|
||||
result.append(full_word)
|
||||
full_word = ""
|
||||
continue
|
||||
# Strip white spaces on the right
|
||||
if tok_extended.rstrip and i > 0:
|
||||
if tok_extended.rstrip and right:
|
||||
# A bit counter-intuitive but we strip the left of the string
|
||||
# since tok_extended.rstrip means the special token is eating all white spaces on its right
|
||||
sub_text = sub_text.lstrip()
|
||||
tokens[i + 1] = right.lstrip()
|
||||
# Strip white spaces on the left
|
||||
if tok_extended.lstrip and i < len(split_text) - 1:
|
||||
sub_text = sub_text.rstrip() # Opposite here
|
||||
if tok_extended.lstrip and left:
|
||||
tokens[i - 1] = left.rstrip() # Opposite here
|
||||
else:
|
||||
# We strip left and right by default
|
||||
if i < len(split_text) - 1:
|
||||
sub_text = sub_text.rstrip()
|
||||
if i > 0:
|
||||
sub_text = sub_text.lstrip()
|
||||
|
||||
if i == 0 and not sub_text:
|
||||
result.append(tok)
|
||||
elif i == len(split_text) - 1:
|
||||
if sub_text:
|
||||
result.append(sub_text)
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
if sub_text:
|
||||
result.append(sub_text)
|
||||
result.append(tok)
|
||||
return result
|
||||
|
||||
def split_on_tokens(tok_list, text):
|
||||
if not text.strip():
|
||||
return []
|
||||
if not tok_list:
|
||||
return self._tokenize(text)
|
||||
|
||||
if right:
|
||||
tokens[i + 1] = right.lstrip()
|
||||
if left:
|
||||
tokens[i - 1] = left.rstrip()
|
||||
# ["This is something", "<special_token_1>", "else"]
|
||||
tokenized_text = []
|
||||
text_list = [text]
|
||||
for tok in tok_list:
|
||||
tokenized_text = []
|
||||
for sub_text in text_list:
|
||||
if sub_text not in self.unique_no_split_tokens:
|
||||
tokenized_text.extend(split_on_token(tok, sub_text))
|
||||
for token in tokens:
|
||||
# Need to skip eventual empty (fully stripped) tokens
|
||||
if not token:
|
||||
continue
|
||||
if token in no_split_token:
|
||||
tokenized_text.append(token)
|
||||
else:
|
||||
tokenized_text.append(sub_text)
|
||||
text_list = tokenized_text
|
||||
|
||||
return list(
|
||||
itertools.chain.from_iterable(
|
||||
(
|
||||
self._tokenize(token) if token not in self.unique_no_split_tokens else [token]
|
||||
for token in tokenized_text
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
no_split_token = self.unique_no_split_tokens
|
||||
tokenized_text = split_on_tokens(no_split_token, text)
|
||||
tokenized_text.extend(self._tokenize(token))
|
||||
# ["This", " is", " something", "<special_token_1>", "else"]
|
||||
return tokenized_text
|
||||
|
||||
def _tokenize(self, text, **kwargs):
|
||||
|
||||
@@ -55,7 +55,7 @@ from transformers.testing_utils import (
|
||||
require_torch,
|
||||
slow,
|
||||
)
|
||||
from transformers.tokenization_utils import AddedToken
|
||||
from transformers.tokenization_utils import AddedToken, Trie
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -1659,6 +1659,34 @@ class TokenizerTesterMixin:
|
||||
encoded_sequences_batch_padded_2[key],
|
||||
)
|
||||
|
||||
@require_tokenizers
|
||||
def test_added_token_are_matched_longest_first(self):
|
||||
if not self.test_slow_tokenizer:
|
||||
self.skipTest("This test is only for slow tokenizers")
|
||||
return
|
||||
tokenizers = self.get_tokenizers(fast=False)
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
try:
|
||||
tokenizer.add_tokens([AddedToken("extra_id_1")])
|
||||
tokenizer.add_tokens([AddedToken("extra_id_100")])
|
||||
except Exception:
|
||||
# Canine cannot add tokens which are not codepoints
|
||||
self.skipTest("Cannot add those Added tokens")
|
||||
|
||||
# XXX: This used to split on `extra_id_1` first we're matching
|
||||
# longest first now.
|
||||
tokens = tokenizer.tokenize("This is some extra_id_100")
|
||||
self.assertIn("extra_id_100", tokens)
|
||||
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
tokenizer.add_tokens([AddedToken("extra_id_100")])
|
||||
tokenizer.add_tokens([AddedToken("extra_id_1")])
|
||||
|
||||
tokens = tokenizer.tokenize("This is some extra_id_100")
|
||||
self.assertIn("extra_id_100", tokens)
|
||||
|
||||
@require_tokenizers
|
||||
def test_added_token_serializable(self):
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
@@ -3489,3 +3517,21 @@ class TokenizerPushToHubTester(unittest.TestCase):
|
||||
|
||||
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
|
||||
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
||||
|
||||
|
||||
class TrieTest(unittest.TestCase):
|
||||
def test_trie(self):
|
||||
trie = Trie()
|
||||
trie.add("Hello 友達")
|
||||
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}})
|
||||
trie.add("Hello")
|
||||
trie.data
|
||||
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}})
|
||||
|
||||
def test_trie_split(self):
|
||||
trie = Trie()
|
||||
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])
|
||||
trie.add("[CLS]")
|
||||
trie.add("extra_id_1")
|
||||
trie.add("extra_id_100")
|
||||
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
|
||||
|
||||
Reference in New Issue
Block a user