Fixing 1-length special tokens cut. (#13862)

This commit is contained in:
Nicolas Patry
2021-10-05 12:26:54 +02:00
committed by GitHub
parent 7051b89267
commit 7079a99e76
2 changed files with 53 additions and 29 deletions

View File

@@ -3562,3 +3562,15 @@ class TrieTest(unittest.TestCase):
trie.add("extra_id_1")
trie.add("extra_id_100")
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
def test_trie_single(self):
trie = Trie()
trie.add("A")
self.assertEqual(trie.split("ABC"), ["A", "BC"])
self.assertEqual(trie.split("BCA"), ["BC", "A"])
def test_trie_final(self):
trie = Trie()
trie.add("TOKEN]")
trie.add("[SPECIAL_TOKEN]")
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])