From 3dd538c4d37248961d4cf99f4c07e8a5fe54984c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 9 Sep 2021 17:26:16 +0200 Subject: [PATCH] [Tentative] Moving slow tokenizer to the Trie world. (#13220) * Moving slow tokenizer to the Trie world. * Adding more docstrings to the Trie. * Fixing doctest (incompatible wiht our format? ) * Update src/transformers/tokenization_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Adding a lot more comment into the internals of this algorithm. * Cleaner doc. * Fixing the namings. * Update src/transformers/tokenization_utils.py Co-authored-by: Lysandre Debut * quality. * Fixing longest first match. * Small improvements to cuts + more test + canine resistant test. * Fixing fast test. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre Debut --- .../models/wav2vec2/tokenization_wav2vec2.py | 4 + src/transformers/tokenization_utils.py | 280 +++++++++++++----- tests/test_tokenization_common.py | 48 ++- 3 files changed, 256 insertions(+), 76 deletions(-) diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py index e6d1092b1e..0c8eb31d01 100644 --- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py +++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py @@ -148,6 +148,8 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): if len(token) > 1: self.unique_no_split_tokens.append(token) + self._create_trie(self.unique_no_split_tokens) + @property def word_delimiter_token(self) -> str: """ @@ -330,6 +332,8 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): self._additional_special_tokens.append(AddedToken(token)) _insert_one_token_to_ordered_list(self.unique_no_split_tokens, token) + self._create_trie(self.unique_no_split_tokens) + return len(tokens_to_add) diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 
class Trie:
    """
    Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass
    Loose reference https://en.wikipedia.org/wiki/Trie
    """

    def __init__(self):
        # Nested dicts keyed by single characters. The special key `""` marks the end of an added word.
        self.data = {}

    def add(self, word: str):
        """
        Passes over every char of `word` and adds it to the internal `data` trie representation, one nested dict
        per character. The special key `""` is used to represent termination.

        This function is idempotent, adding twice the same word will leave the trie unchanged. Adding the empty
        string is a no-op.

        Args:
            word (:obj:`str`): The word to register in the trie.

        Example::

            >>> trie = Trie()
            >>> trie.add("Hello 友達")
            >>> trie.data
            {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}
            >>> trie.add("Hello")
            >>> trie.data
            {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
        """
        if not word:
            # Prevent empty string: it would put the termination marker on the root.
            return
        ref = self.data
        for char in word:
            # Reuse the existing child node when there is one, create it otherwise.
            ref = ref.setdefault(char, {})
        ref[""] = 1

    def split(self, text: str) -> List[str]:
        """
        Will look for the words added to the trie within `text`. Output is the original string split along the
        boundaries of the words found. Concatenating the returned pieces gives back `text`, and no piece is empty.

        Matching rules:

        - The scan goes left to right and is greedy: the first word that gets completed wins, even if a longer word
          that started earlier is still only partially matched (with "abcd" and "bc" in the trie, "abcd" is split
          into ["a", "bc", "d"]).
        - If several words are completed on the same character, the one that starts the earliest is kept.
        - Once a word is completed, the scan looks ahead to extend it to the longest word sharing the same start
          ("extra_id_100" is preferred over "extra_id_1").
        - Matches never overlap: characters consumed by a match cannot start or continue another match.

        Args:
            text (:obj:`str`): The text to split.

        Returns:
            :obj:`List[str]`: The pieces of `text`, in order. An empty `text` gives an empty list.

        Example::

            >>> trie = Trie()
            >>> trie.split("[CLS] This is a extra_id_100")
            ["[CLS] This is a extra_id_100"]
            >>> trie.add("[CLS]")
            >>> trie.add("extra_id_1")
            >>> trie.add("extra_id_100")
            >>> trie.split("[CLS] This is a extra_id_100")
            ["[CLS]", " This is a ", "extra_id_100"]
        """
        # Indexes are counted left of the chars index.
        # "hello", index 0, is left of h, index 1 is between h and e.
        # index 5 is right of the "o".
        text_length = len(text)

        # `states` maps every possible start index (as above) of a partial match to the trie node reached after
        # consuming the characters seen so far. Several starts have to be tracked at once: if the trie contains
        # "blowing" and "lower" and we encounter "blower", the result must be ["b", "lower"].
        # Keys are inserted in increasing order, so iterating the dict visits the earliest start first.
        states = {}

        # Every index where the text has to be cut. The cut at 0 is forced here, the one at len(text) is
        # added after the scan.
        offsets = [0]

        # Index up to which the text is already consumed by a match; characters before it are not looked at again.
        skip = 0

        for current, current_char in enumerate(text):
            if current < skip:
                # This character belongs to a match extended by the lookahead
                # (prevents matching both extra_id_100 and id_100).
                continue

            # Advance every partial match with the current character. The ones that cannot consume it are
            # dropped (looking at "lowball" with "lower" in the trie, the match started at 0 dies on "b").
            advanced = {}
            for start, trie_pointer in states.items():
                if current_char in trie_pointer:
                    advanced[start] = trie_pointer[current_char]
            # The current character may also be the first character of a word. Registering it before looking for
            # completed words is what makes one-character words match.
            if current_char in self.data:
                advanced[current] = self.data[current_char]
            states = advanced

            # Keep only the earliest-starting completed word, so that the recorded offsets stay ordered.
            match_start = next((start for start, trie_pointer in states.items() if "" in trie_pointer), None)
            if match_start is None:
                continue

            # Lookahead to match longest first.
            # Important in case of extra_id_1 vs extra_id_100
            trie_pointer = states[match_start]
            end = current + 1
            lookahead_index = end
            while lookahead_index < text_length and text[lookahead_index] in trie_pointer:
                trie_pointer = trie_pointer[text[lookahead_index]]
                lookahead_index += 1
                if "" in trie_pointer:
                    # A longer word ends here, remember it. Keep looking, there might be an even longer one.
                    end = lookahead_index

            # Storing and resetting: every other partial match overlaps the word just found, drop them all.
            offsets.append(match_start)
            offsets.append(end)
            skip = end
            states = {}

        # We have all the offsets now, we just need to do the actual splitting.
        # The last cut is forced at the end of the text.
        offsets.append(text_length)
        tokens = []
        start = 0
        for end in offsets:
            if start == end:
                # This might happen if there's a match at index 0
                # we're also preventing zero-width cuts in case of two
                # consecutive matches
                continue
            tokens.append(text[start:end])
            start = end

        return tokens
- # We use them for GPT2 and Roberta to have different behavior depending on the special token - # Cf. https://github.com/huggingface/transformers/pull/2778 - # and https://github.com/huggingface/transformers/issues/3788 + no_split_token = set(self.unique_no_split_tokens) + tokens = self.tokens_trie.split(text) + # ["This is something", "", " else"] + for i, token in enumerate(tokens): + if token in no_split_token: + tok_extended = all_special_tokens_extended.get(token, None) + left = tokens[i - 1] if i > 0 else None + right = tokens[i + 1] if i < len(tokens) - 1 else None if isinstance(tok_extended, AddedToken): - if tok_extended.single_word: - # Try to avoid splitting on token - if ( - i < len(split_text) - 1 - and not _is_end_of_word(sub_text) - and not _is_start_of_word(split_text[i + 1]) - ): - # Don't extract the special token - full_word += sub_text + tok - elif full_word: - full_word += sub_text - result.append(full_word) - full_word = "" - continue - # Strip white spaces on the right - if tok_extended.rstrip and i > 0: + if tok_extended.rstrip and right: # A bit counter-intuitive but we strip the left of the string # since tok_extended.rstrip means the special token is eating all white spaces on its right - sub_text = sub_text.lstrip() + tokens[i + 1] = right.lstrip() # Strip white spaces on the left - if tok_extended.lstrip and i < len(split_text) - 1: - sub_text = sub_text.rstrip() # Opposite here + if tok_extended.lstrip and left: + tokens[i - 1] = left.rstrip() # Opposite here else: # We strip left and right by default - if i < len(split_text) - 1: - sub_text = sub_text.rstrip() - if i > 0: - sub_text = sub_text.lstrip() - - if i == 0 and not sub_text: - result.append(tok) - elif i == len(split_text) - 1: - if sub_text: - result.append(sub_text) - else: - pass - else: - if sub_text: - result.append(sub_text) - result.append(tok) - return result - - def split_on_tokens(tok_list, text): - if not text.strip(): - return [] - if not tok_list: - return 
self._tokenize(text) - - tokenized_text = [] - text_list = [text] - for tok in tok_list: - tokenized_text = [] - for sub_text in text_list: - if sub_text not in self.unique_no_split_tokens: - tokenized_text.extend(split_on_token(tok, sub_text)) - else: - tokenized_text.append(sub_text) - text_list = tokenized_text - - return list( - itertools.chain.from_iterable( - ( - self._tokenize(token) if token not in self.unique_no_split_tokens else [token] - for token in tokenized_text - ) - ) - ) - - no_split_token = self.unique_no_split_tokens - tokenized_text = split_on_tokens(no_split_token, text) + if right: + tokens[i + 1] = right.lstrip() + if left: + tokens[i - 1] = left.rstrip() + # ["This is something", "", "else"] + tokenized_text = [] + for token in tokens: + # Need to skip eventual empty (fully stripped) tokens + if not token: + continue + if token in no_split_token: + tokenized_text.append(token) + else: + tokenized_text.extend(self._tokenize(token)) + # ["This", " is", " something", "", "else"] return tokenized_text def _tokenize(self, text, **kwargs): diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 5dfc738814..5cea9a8c4a 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -55,7 +55,7 @@ from transformers.testing_utils import ( require_torch, slow, ) -from transformers.tokenization_utils import AddedToken +from transformers.tokenization_utils import AddedToken, Trie if is_torch_available(): @@ -1659,6 +1659,34 @@ class TokenizerTesterMixin: encoded_sequences_batch_padded_2[key], ) + @require_tokenizers + def test_added_token_are_matched_longest_first(self): + if not self.test_slow_tokenizer: + self.skipTest("This test is only for slow tokenizers") + return + tokenizers = self.get_tokenizers(fast=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + try: + tokenizer.add_tokens([AddedToken("extra_id_1")]) + 
class TrieTest(unittest.TestCase):
    """Unit tests for the `Trie` class imported from `transformers.tokenization_utils`."""

    def test_trie(self):
        """Adding words builds nested per-character dicts, with `""` marking the end of each word."""
        trie = Trie()
        trie.add("Hello 友達")
        self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}})
        # Adding a prefix of an existing word only adds a termination marker on the existing path.
        trie.add("Hello")
        self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}})

    def test_trie_split(self):
        """An empty trie leaves the text whole; added words are cut out, longest match first."""
        trie = Trie()
        self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])
        trie.add("[CLS]")
        trie.add("extra_id_1")
        trie.add("extra_id_100")
        self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])