From d70919e6d5629f3f4d19aefff3eec1219055f003 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 8 Oct 2021 10:10:38 +0200 Subject: [PATCH] Adding support for tokens being suffixes or part of each other. (#13918) * Adding support for tokens being suffixes or part of each other. * Better test name. --- src/transformers/tokenization_utils.py | 56 +++++++++++++++++++------- tests/test_tokenization_common.py | 21 ++++++++++ 2 files changed, 62 insertions(+), 15 deletions(-) diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index f10267fd4b..a8bcb98f85 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -150,26 +150,44 @@ class Trie: # Lookahead to match longest first # Important in case of extra_id_1 vs extra_id_100 - lookahead_index = current - end = current - next_char = text[lookahead_index] if lookahead_index < len(text) else None - while next_char in trie_pointer: - trie_pointer = trie_pointer[next_char] - lookahead_index += 1 - if "" in trie_pointer: - end = lookahead_index - skip = lookahead_index - - if lookahead_index == len(text): - # End of string + # Here we are also actively looking for other earlier partial + # matches + # "[CLS]", "L", we need to match CLS even if L is special + for lookstart, looktrie_pointer in states.items(): + if lookstart > start: + # This partial match is later, we can stop looking break - next_char = text[lookahead_index] - # End lookahead + elif lookstart < start: + # This partial match is earlier, the trie pointer + # was already updated, so index is + 1 + lookahead_index = current + 1 + end = current + 1 + else: + # Here lookstart == start and + # looktrie_pointer == trie_pointer + # It wasn't updated yet so indices are current ones + lookahead_index = current + end = current + next_char = text[lookahead_index] if lookahead_index < len(text) else None + while next_char in looktrie_pointer: + looktrie_pointer = looktrie_pointer[next_char] + lookahead_index += 1 + if "" in looktrie_pointer: + start = lookstart + end = lookahead_index + skip = lookahead_index + + if lookahead_index == len(text): + # End of string + break + next_char = text[lookahead_index] + # End lookahead # Storing and resetting offsets.append(start) offsets.append(end) reset = True + break elif current_char in trie_pointer: # The current character being looked at has a match within the trie # update the pointer (it will be stored back into states later). @@ -210,6 +228,9 @@ class Trie: # item so we need to break. break + return self.cut_text(text, offsets) + + def cut_text(self, text, offsets): # We have all the offsets now, we just need to do the actual splitting. # We need to eventually add the first part of the string and the eventual # last part. @@ -217,7 +238,12 @@ class Trie: tokens = [] start = 0 for end in offsets: - if start == end: + if start > end: + logger.error( + "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway." + ) + continue + elif start == end: # This might happen if there's a match at index 0 # we're also preventing zero-width cuts in case of two # consecutive matches diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 1a58f51692..36a9320541 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -3574,3 +3574,24 @@ class TrieTest(unittest.TestCase): trie.add("TOKEN]") trie.add("[SPECIAL_TOKEN]") self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"]) + + def test_trie_subtokens(self): + trie = Trie() + trie.add("A") + trie.add("P") + trie.add("[SPECIAL_TOKEN]") + self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"]) + + def test_trie_suffix_tokens(self): + trie = Trie() + trie.add("AB") + trie.add("B") + trie.add("C") + self.assertEqual(trie.split("ABC"), ["AB", "C"]) + + def test_cut_text_hardening(self): + # Even if the offsets are wrong, we necessarily output correct string + # parts. + trie = Trie() + parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3]) + self.assertEqual(parts, ["AB", "C"])