Adding support for tokens being suffixes or part of each other. (#13918)
* Adding support for tokens being suffixes or part of each other.
* Better test name.
This commit is contained in:
@@ -150,26 +150,44 @@ class Trie:
|
|||||||
|
|
||||||
# Lookahead to match longest first
|
# Lookahead to match longest first
|
||||||
# Important in case of extra_id_1 vs extra_id_100
|
# Important in case of extra_id_1 vs extra_id_100
|
||||||
lookahead_index = current
|
# Here we are also actively looking for other earlier partial
|
||||||
end = current
|
# matches
|
||||||
next_char = text[lookahead_index] if lookahead_index < len(text) else None
|
# "[CLS]", "L", we need to match CLS even if L is special
|
||||||
while next_char in trie_pointer:
|
for lookstart, looktrie_pointer in states.items():
|
||||||
trie_pointer = trie_pointer[next_char]
|
if lookstart > start:
|
||||||
lookahead_index += 1
|
# This partial match is later, we can stop looking
|
||||||
if "" in trie_pointer:
|
|
||||||
end = lookahead_index
|
|
||||||
skip = lookahead_index
|
|
||||||
|
|
||||||
if lookahead_index == len(text):
|
|
||||||
# End of string
|
|
||||||
break
|
break
|
||||||
next_char = text[lookahead_index]
|
elif lookstart < start:
|
||||||
# End lookahead
|
# This partial match is earlier, the trie pointer
|
||||||
|
# was already updated, so index is + 1
|
||||||
|
lookahead_index = current + 1
|
||||||
|
end = current + 1
|
||||||
|
else:
|
||||||
|
# Here lookstart == start and
|
||||||
|
# looktrie_pointer == trie_pointer
|
||||||
|
# It wasn't updated yet so indices are current ones
|
||||||
|
lookahead_index = current
|
||||||
|
end = current
|
||||||
|
next_char = text[lookahead_index] if lookahead_index < len(text) else None
|
||||||
|
while next_char in looktrie_pointer:
|
||||||
|
looktrie_pointer = looktrie_pointer[next_char]
|
||||||
|
lookahead_index += 1
|
||||||
|
if "" in looktrie_pointer:
|
||||||
|
start = lookstart
|
||||||
|
end = lookahead_index
|
||||||
|
skip = lookahead_index
|
||||||
|
|
||||||
|
if lookahead_index == len(text):
|
||||||
|
# End of string
|
||||||
|
break
|
||||||
|
next_char = text[lookahead_index]
|
||||||
|
# End lookahead
|
||||||
|
|
||||||
# Storing and resetting
|
# Storing and resetting
|
||||||
offsets.append(start)
|
offsets.append(start)
|
||||||
offsets.append(end)
|
offsets.append(end)
|
||||||
reset = True
|
reset = True
|
||||||
|
break
|
||||||
elif current_char in trie_pointer:
|
elif current_char in trie_pointer:
|
||||||
# The current character being looked at has a match within the trie
|
# The current character being looked at has a match within the trie
|
||||||
# update the pointer (it will be stored back into states later).
|
# update the pointer (it will be stored back into states later).
|
||||||
@@ -210,6 +228,9 @@ class Trie:
|
|||||||
# item so we need to break.
|
# item so we need to break.
|
||||||
break
|
break
|
||||||
|
|
||||||
|
return self.cut_text(text, offsets)
|
||||||
|
|
||||||
|
def cut_text(self, text, offsets):
|
||||||
# We have all the offsets now, we just need to do the actual splitting.
|
# We have all the offsets now, we just need to do the actual splitting.
|
||||||
# We need to eventually add the first part of the string and the eventual
|
# We need to eventually add the first part of the string and the eventual
|
||||||
# last part.
|
# last part.
|
||||||
@@ -217,7 +238,12 @@ class Trie:
|
|||||||
tokens = []
|
tokens = []
|
||||||
start = 0
|
start = 0
|
||||||
for end in offsets:
|
for end in offsets:
|
||||||
if start == end:
|
if start > end:
|
||||||
|
logger.error(
|
||||||
|
"There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
elif start == end:
|
||||||
# This might happen if there's a match at index 0
|
# This might happen if there's a match at index 0
|
||||||
# we're also preventing zero-width cuts in case of two
|
# we're also preventing zero-width cuts in case of two
|
||||||
# consecutive matches
|
# consecutive matches
|
||||||
|
|||||||
@@ -3574,3 +3574,24 @@ class TrieTest(unittest.TestCase):
|
|||||||
trie.add("TOKEN]")
|
trie.add("TOKEN]")
|
||||||
trie.add("[SPECIAL_TOKEN]")
|
trie.add("[SPECIAL_TOKEN]")
|
||||||
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
|
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
|
||||||
|
|
||||||
|
def test_trie_subtokens(self):
|
||||||
|
trie = Trie()
|
||||||
|
trie.add("A")
|
||||||
|
trie.add("P")
|
||||||
|
trie.add("[SPECIAL_TOKEN]")
|
||||||
|
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
|
||||||
|
|
||||||
|
def test_trie_suffix_tokens(self):
|
||||||
|
trie = Trie()
|
||||||
|
trie.add("AB")
|
||||||
|
trie.add("B")
|
||||||
|
trie.add("C")
|
||||||
|
self.assertEqual(trie.split("ABC"), ["AB", "C"])
|
||||||
|
|
||||||
|
def test_cut_text_hardening(self):
|
||||||
|
# Even if the offsets are wrong, we necessarily output correct string
|
||||||
|
# parts.
|
||||||
|
trie = Trie()
|
||||||
|
parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
|
||||||
|
self.assertEqual(parts, ["AB", "C"])
|
||||||
|
|||||||
Reference in New Issue
Block a user