Fixing a pathological case for slow tokenizers (#14981)
* Fixing a pathological case for slow tokenizers
* Update src/transformers/tokenization_utils.py
This commit is contained in:
@@ -131,7 +131,7 @@ class Trie:
|
|||||||
# This is used by the lookahead which needs to skip over
|
# This is used by the lookahead which needs to skip over
|
||||||
# some text where the full match exceeded the place in the initial
|
# some text where the full match exceeded the place in the initial
|
||||||
# for loop
|
# for loop
|
||||||
skip = None
|
skip = 0
|
||||||
# Main loop, Giving this algorithm O(n) complexity
|
# Main loop, Giving this algorithm O(n) complexity
|
||||||
for current, current_char in enumerate(text):
|
for current, current_char in enumerate(text):
|
||||||
if skip and current < skip:
|
if skip and current < skip:
|
||||||
@@ -175,6 +175,11 @@ class Trie:
|
|||||||
lookahead_index = current
|
lookahead_index = current
|
||||||
end = current
|
end = current
|
||||||
next_char = text[lookahead_index] if lookahead_index < len(text) else None
|
next_char = text[lookahead_index] if lookahead_index < len(text) else None
|
||||||
|
if "" in looktrie_pointer:
|
||||||
|
start = lookstart
|
||||||
|
end = lookahead_index
|
||||||
|
skip = lookahead_index
|
||||||
|
|
||||||
while next_char in looktrie_pointer:
|
while next_char in looktrie_pointer:
|
||||||
looktrie_pointer = looktrie_pointer[next_char]
|
looktrie_pointer = looktrie_pointer[next_char]
|
||||||
lookahead_index += 1
|
lookahead_index += 1
|
||||||
@@ -219,7 +224,7 @@ class Trie:
|
|||||||
|
|
||||||
# If this character is a starting character within the trie
|
# If this character is a starting character within the trie
|
||||||
# start keeping track of this partial match.
|
# start keeping track of this partial match.
|
||||||
if current_char in self.data:
|
if current >= skip and current_char in self.data:
|
||||||
states[current] = self.data[current_char]
|
states[current] = self.data[current_char]
|
||||||
|
|
||||||
# We have a cut at the end with states.
|
# We have a cut at the end with states.
|
||||||
|
|||||||
@@ -3687,6 +3687,13 @@ class TrieTest(unittest.TestCase):
|
|||||||
trie.add("C")
|
trie.add("C")
|
||||||
self.assertEqual(trie.split("ABC"), ["AB", "C"])
|
self.assertEqual(trie.split("ABC"), ["AB", "C"])
|
||||||
|
|
||||||
|
def test_trie_skip(self):
|
||||||
|
trie = Trie()
|
||||||
|
trie.add("ABC")
|
||||||
|
trie.add("B")
|
||||||
|
trie.add("CD")
|
||||||
|
self.assertEqual(trie.split("ABCD"), ["ABC", "D"])
|
||||||
|
|
||||||
def test_cut_text_hardening(self):
|
def test_cut_text_hardening(self):
|
||||||
# Even if the offsets are wrong, we necessarily output correct string
|
# Even if the offsets are wrong, we necessarily output correct string
|
||||||
# parts.
|
# parts.
|
||||||
|
|||||||
Reference in New Issue
Block a user