Fixing 1-length special tokens cut. (#13862)
This commit is contained in:
@@ -20,6 +20,7 @@ import bisect
|
|||||||
import itertools
|
import itertools
|
||||||
import re
|
import re
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
from collections import OrderedDict
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union, overload
|
from typing import Any, Dict, List, Optional, Tuple, Union, overload
|
||||||
|
|
||||||
from .file_utils import PaddingStrategy, TensorType, add_end_docstrings
|
from .file_utils import PaddingStrategy, TensorType, add_end_docstrings
|
||||||
@@ -102,7 +103,6 @@ class Trie:
|
|||||||
>>> trie.split("[CLS] This is a extra_id_100")
|
>>> trie.split("[CLS] This is a extra_id_100")
|
||||||
["[CLS]", " This is a ", "extra_id_100"]
|
["[CLS]", " This is a ", "extra_id_100"]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# indexes are counted left of the chars index.
|
# indexes are counted left of the chars index.
|
||||||
# "hello", index 0, is left of h, index 1 is between h and e.
|
# "hello", index 0, is left of h, index 1 is between h and e.
|
||||||
# index 5 is right of the "o".
|
# index 5 is right of the "o".
|
||||||
@@ -115,7 +115,7 @@ class Trie:
|
|||||||
# If the trie contains, "blowing", and "lower" and we encounter the
|
# If the trie contains, "blowing", and "lower" and we encounter the
|
||||||
# string "blower", we need to split into ["b", "lower"].
|
# string "blower", we need to split into ["b", "lower"].
|
||||||
# This is where we need to keep track of multiple possible starts.
|
# This is where we need to keep track of multiple possible starts.
|
||||||
states = {}
|
states = OrderedDict()
|
||||||
|
|
||||||
# This will contain every indices where we need
|
# This will contain every indices where we need
|
||||||
# to cut.
|
# to cut.
|
||||||
@@ -144,18 +144,14 @@ class Trie:
|
|||||||
|
|
||||||
# In this case, we already have partial matches (But unfinished)
|
# In this case, we already have partial matches (But unfinished)
|
||||||
for start, trie_pointer in states.items():
|
for start, trie_pointer in states.items():
|
||||||
if current_char in trie_pointer:
|
|
||||||
# The current character being looked at has a match within the trie
|
|
||||||
# update the pointer (it will be stored back into states later).
|
|
||||||
trie_pointer = trie_pointer[current_char]
|
|
||||||
if "" in trie_pointer:
|
if "" in trie_pointer:
|
||||||
# This is a final match, we need to reset and
|
# This is a final match, we need to reset and
|
||||||
# store the results in `offsets`.
|
# store the results in `offsets`.
|
||||||
|
|
||||||
# Lookahead to match longest first
|
# Lookahead to match longest first
|
||||||
# Important in case of extra_id_1 vs extra_id_100
|
# Important in case of extra_id_1 vs extra_id_100
|
||||||
lookahead_index = current + 1
|
lookahead_index = current
|
||||||
end = current + 1
|
end = current
|
||||||
next_char = text[lookahead_index] if lookahead_index < len(text) else None
|
next_char = text[lookahead_index] if lookahead_index < len(text) else None
|
||||||
while next_char in trie_pointer:
|
while next_char in trie_pointer:
|
||||||
trie_pointer = trie_pointer[next_char]
|
trie_pointer = trie_pointer[next_char]
|
||||||
@@ -174,6 +170,10 @@ class Trie:
|
|||||||
offsets.append(start)
|
offsets.append(start)
|
||||||
offsets.append(end)
|
offsets.append(end)
|
||||||
reset = True
|
reset = True
|
||||||
|
elif current_char in trie_pointer:
|
||||||
|
# The current character being looked at has a match within the trie
|
||||||
|
# update the pointer (it will be stored back into states later).
|
||||||
|
trie_pointer = trie_pointer[current_char]
|
||||||
|
|
||||||
# Storing back the new pointer into the states.
|
# Storing back the new pointer into the states.
|
||||||
# Partial matches got longer by one.
|
# Partial matches got longer by one.
|
||||||
@@ -198,6 +198,18 @@ class Trie:
|
|||||||
if current_char in self.data:
|
if current_char in self.data:
|
||||||
states[current] = self.data[current_char]
|
states[current] = self.data[current_char]
|
||||||
|
|
||||||
|
# We have a cut at the end with states.
|
||||||
|
for start, trie_pointer in states.items():
|
||||||
|
if "" in trie_pointer:
|
||||||
|
# This is a final match, we need to reset and
|
||||||
|
# store the results in `offsets`.
|
||||||
|
end = len(text)
|
||||||
|
offsets.append(start)
|
||||||
|
offsets.append(end)
|
||||||
|
# Longest cut is always the one with lower start so the first
|
||||||
|
# item so we need to break.
|
||||||
|
break
|
||||||
|
|
||||||
# We have all the offsets now, we just need to do the actual splitting.
|
# We have all the offsets now, we just need to do the actual splitting.
|
||||||
# We need to eventually add the first part of the string and the eventual
|
# We need to eventually add the first part of the string and the eventual
|
||||||
# last part.
|
# last part.
|
||||||
|
|||||||
@@ -3562,3 +3562,15 @@ class TrieTest(unittest.TestCase):
|
|||||||
trie.add("extra_id_1")
|
trie.add("extra_id_1")
|
||||||
trie.add("extra_id_100")
|
trie.add("extra_id_100")
|
||||||
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
|
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
|
||||||
|
|
||||||
|
def test_trie_single(self):
|
||||||
|
trie = Trie()
|
||||||
|
trie.add("A")
|
||||||
|
self.assertEqual(trie.split("ABC"), ["A", "BC"])
|
||||||
|
self.assertEqual(trie.split("BCA"), ["BC", "A"])
|
||||||
|
|
||||||
|
def test_trie_final(self):
|
||||||
|
trie = Trie()
|
||||||
|
trie.add("TOKEN]")
|
||||||
|
trie.add("[SPECIAL_TOKEN]")
|
||||||
|
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
|
||||||
|
|||||||
Reference in New Issue
Block a user