Adding support for tokens being suffixes or part of each other. (#13918)
* Adding support for tokens being suffixes or part of each other. * Better test name.
This commit is contained in:
@@ -150,26 +150,44 @@ class Trie:
|
||||
|
||||
# Lookahead to match longest first
|
||||
# Important in case of extra_id_1 vs extra_id_100
|
||||
lookahead_index = current
|
||||
end = current
|
||||
next_char = text[lookahead_index] if lookahead_index < len(text) else None
|
||||
while next_char in trie_pointer:
|
||||
trie_pointer = trie_pointer[next_char]
|
||||
lookahead_index += 1
|
||||
if "" in trie_pointer:
|
||||
end = lookahead_index
|
||||
skip = lookahead_index
|
||||
|
||||
if lookahead_index == len(text):
|
||||
# End of string
|
||||
# Here we are also actively looking for other earlier partial
|
||||
# matches
|
||||
# "[CLS]", "L", we need to match CLS even if L is special
|
||||
for lookstart, looktrie_pointer in states.items():
|
||||
if lookstart > start:
|
||||
# This partial match is later, we can stop looking
|
||||
break
|
||||
next_char = text[lookahead_index]
|
||||
# End lookahead
|
||||
elif lookstart < start:
|
||||
# This partial match is earlier, the trie pointer
|
||||
# was already updated, so index is + 1
|
||||
lookahead_index = current + 1
|
||||
end = current + 1
|
||||
else:
|
||||
# Here lookstart == start and
|
||||
# looktrie_pointer == trie_pointer
|
||||
# It wasn't updated yet so indices are current ones
|
||||
lookahead_index = current
|
||||
end = current
|
||||
next_char = text[lookahead_index] if lookahead_index < len(text) else None
|
||||
while next_char in looktrie_pointer:
|
||||
looktrie_pointer = looktrie_pointer[next_char]
|
||||
lookahead_index += 1
|
||||
if "" in looktrie_pointer:
|
||||
start = lookstart
|
||||
end = lookahead_index
|
||||
skip = lookahead_index
|
||||
|
||||
if lookahead_index == len(text):
|
||||
# End of string
|
||||
break
|
||||
next_char = text[lookahead_index]
|
||||
# End lookahead
|
||||
|
||||
# Storing and resetting
|
||||
offsets.append(start)
|
||||
offsets.append(end)
|
||||
reset = True
|
||||
break
|
||||
elif current_char in trie_pointer:
|
||||
# The current character being looked at has a match within the trie
|
||||
# update the pointer (it will be stored back into states later).
|
||||
@@ -210,6 +228,9 @@ class Trie:
|
||||
# item so we need to break.
|
||||
break
|
||||
|
||||
return self.cut_text(text, offsets)
|
||||
|
||||
def cut_text(self, text, offsets):
|
||||
# We have all the offsets now, we just need to do the actual splitting.
|
||||
# We need to eventually add the first part of the string and the eventual
|
||||
# last part.
|
||||
@@ -217,7 +238,12 @@ class Trie:
|
||||
tokens = []
|
||||
start = 0
|
||||
for end in offsets:
|
||||
if start == end:
|
||||
if start > end:
|
||||
logger.error(
|
||||
"There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway."
|
||||
)
|
||||
continue
|
||||
elif start == end:
|
||||
# This might happen if there's a match at index 0
|
||||
# we're also preventing zero-width cuts in case of two
|
||||
# consecutive matches
|
||||
|
||||
@@ -3574,3 +3574,24 @@ class TrieTest(unittest.TestCase):
|
||||
trie.add("TOKEN]")
|
||||
trie.add("[SPECIAL_TOKEN]")
|
||||
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
|
||||
|
||||
def test_trie_subtokens(self):
    """Single-character tokens contained in a longer token must not split it.

    "A" and "P" both occur inside "[SPECIAL_TOKEN]"; the split is still
    expected to return the long token as one piece.
    """
    vocab = Trie()
    for token in ("A", "P", "[SPECIAL_TOKEN]"):
        vocab.add(token)

    pieces = vocab.split("This is something [SPECIAL_TOKEN]")
    self.assertEqual(pieces, ["This is something ", "[SPECIAL_TOKEN]"])
|
||||
|
||||
def test_trie_suffix_tokens(self):
    """A token that is a suffix of another ("B" vs "AB") must not win over it.

    With "AB", "B" and "C" registered, "ABC" is expected to split into
    "AB" followed by "C".
    """
    vocab = Trie()
    for token in ("AB", "B", "C"):
        vocab.add(token)

    pieces = vocab.split("ABC")
    self.assertEqual(pieces, ["AB", "C"])
|
||||
|
||||
def test_cut_text_hardening(self):
    """cut_text must still return valid string parts when given bad offsets.

    The offsets passed here contain repeated and decreasing values; the
    expected output is nevertheless the clean split ["AB", "C"].
    """
    vocab = Trie()
    malformed_offsets = [0, 0, 2, 1, 2, 3]

    pieces = vocab.cut_text("ABC", malformed_offsets)
    self.assertEqual(pieces, ["AB", "C"])
|
||||
|
||||
Reference in New Issue
Block a user