Fixing 1-length special tokens cut. (#13862)
This commit is contained in:
@@ -20,6 +20,7 @@ import bisect
|
||||
import itertools
|
||||
import re
|
||||
import unicodedata
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, overload
|
||||
|
||||
from .file_utils import PaddingStrategy, TensorType, add_end_docstrings
|
||||
@@ -102,7 +103,6 @@ class Trie:
|
||||
>>> trie.split("[CLS] This is a extra_id_100")
|
||||
["[CLS]", " This is a ", "extra_id_100"]
|
||||
"""
|
||||
|
||||
# indexes are counted left of the chars index.
|
||||
# "hello", index 0, is left of h, index 1 is between h and e.
|
||||
# index 5 is right of the "o".
|
||||
@@ -115,7 +115,7 @@ class Trie:
|
||||
# If the trie contains, "blowing", and "lower" and we encounter the
|
||||
# string "blower", we need to split into ["b", "lower"].
|
||||
# This is where we need to keep track of multiple possible starts.
|
||||
states = {}
|
||||
states = OrderedDict()
|
||||
|
||||
# This will contain every indices where we need
|
||||
# to cut.
|
||||
@@ -144,36 +144,36 @@ class Trie:
|
||||
|
||||
# In this case, we already have partial matches (But unfinished)
|
||||
for start, trie_pointer in states.items():
|
||||
if current_char in trie_pointer:
|
||||
if "" in trie_pointer:
|
||||
# This is a final match, we need to reset and
|
||||
# store the results in `offsets`.
|
||||
|
||||
# Lookahead to match longest first
|
||||
# Important in case of extra_id_1 vs extra_id_100
|
||||
lookahead_index = current
|
||||
end = current
|
||||
next_char = text[lookahead_index] if lookahead_index < len(text) else None
|
||||
while next_char in trie_pointer:
|
||||
trie_pointer = trie_pointer[next_char]
|
||||
lookahead_index += 1
|
||||
if "" in trie_pointer:
|
||||
end = lookahead_index
|
||||
skip = lookahead_index
|
||||
|
||||
if lookahead_index == len(text):
|
||||
# End of string
|
||||
break
|
||||
next_char = text[lookahead_index]
|
||||
# End lookahead
|
||||
|
||||
# Storing and resetting
|
||||
offsets.append(start)
|
||||
offsets.append(end)
|
||||
reset = True
|
||||
elif current_char in trie_pointer:
|
||||
# The current character being looked at has a match within the trie
|
||||
# update the pointer (it will be stored back into states later).
|
||||
trie_pointer = trie_pointer[current_char]
|
||||
if "" in trie_pointer:
|
||||
# This is a final match, we need to reset and
|
||||
# store the results in `offsets`.
|
||||
|
||||
# Lookahead to match longest first
|
||||
# Important in case of extra_id_1 vs extra_id_100
|
||||
lookahead_index = current + 1
|
||||
end = current + 1
|
||||
next_char = text[lookahead_index] if lookahead_index < len(text) else None
|
||||
while next_char in trie_pointer:
|
||||
trie_pointer = trie_pointer[next_char]
|
||||
lookahead_index += 1
|
||||
if "" in trie_pointer:
|
||||
end = lookahead_index
|
||||
skip = lookahead_index
|
||||
|
||||
if lookahead_index == len(text):
|
||||
# End of string
|
||||
break
|
||||
next_char = text[lookahead_index]
|
||||
# End lookahead
|
||||
|
||||
# Storing and resetting
|
||||
offsets.append(start)
|
||||
offsets.append(end)
|
||||
reset = True
|
||||
|
||||
# Storing back the new pointer into the states.
|
||||
# Partial matches got longer by one.
|
||||
@@ -198,6 +198,18 @@ class Trie:
|
||||
if current_char in self.data:
|
||||
states[current] = self.data[current_char]
|
||||
|
||||
# We have a cut at the end with states.
|
||||
for start, trie_pointer in states.items():
|
||||
if "" in trie_pointer:
|
||||
# This is a final match, we need to reset and
|
||||
# store the results in `offsets`.
|
||||
end = len(text)
|
||||
offsets.append(start)
|
||||
offsets.append(end)
|
||||
# Longest cut is always the one with lower start so the first
|
||||
# item so we need to break.
|
||||
break
|
||||
|
||||
# We have all the offsets now, we just need to do the actual splitting.
|
||||
# We need to eventually add the first part of the string and the eventual
|
||||
# last part.
|
||||
|
||||
@@ -3562,3 +3562,15 @@ class TrieTest(unittest.TestCase):
|
||||
trie.add("extra_id_1")
|
||||
trie.add("extra_id_100")
|
||||
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
|
||||
|
||||
def test_trie_single(self):
|
||||
trie = Trie()
|
||||
trie.add("A")
|
||||
self.assertEqual(trie.split("ABC"), ["A", "BC"])
|
||||
self.assertEqual(trie.split("BCA"), ["BC", "A"])
|
||||
|
||||
def test_trie_final(self):
|
||||
trie = Trie()
|
||||
trie.add("TOKEN]")
|
||||
trie.add("[SPECIAL_TOKEN]")
|
||||
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
|
||||
|
||||
Reference in New Issue
Block a user