Token healing (#30081)
* token healing impl + trie with extensions * make fixup * prefix-robust space tokenization * examples readme and requirements * make fixup * allow input prompt and model * redundant defaults * Specialized Trie * make fixup * updated tests with new inherited Tree * input ids to auto device_map * rm unused import * Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * naming convention * Revert "naming convention" This reverts commit dd39d9c5b7a969e2d8a8d2a8e54f121b82dc44f0. * naming convention * last -hopefully- changes --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -33,7 +33,7 @@ from transformers import (
|
||||
is_tokenizers_available,
|
||||
)
|
||||
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_tokenizers
|
||||
from transformers.tokenization_utils import Trie
|
||||
from transformers.tokenization_utils import ExtensionsTrie, Trie
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||
@@ -274,3 +274,35 @@ class TrieTest(unittest.TestCase):
|
||||
trie = Trie()
|
||||
parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
|
||||
self.assertEqual(parts, ["AB", "C"])
|
||||
|
||||
|
||||
class ExtensionsTrieTest(unittest.TestCase):
|
||||
def test_extensions(self):
|
||||
# Test searching by prefix
|
||||
trie = ExtensionsTrie()
|
||||
trie.add("foo")
|
||||
trie.add("food")
|
||||
trie.add("foodie")
|
||||
trie.add("helium")
|
||||
self.assertEqual(trie.extensions("foo"), ["foo", "food", "foodie"])
|
||||
self.assertEqual(trie.extensions("helium"), ["helium"])
|
||||
|
||||
def test_empty_prefix(self):
|
||||
trie = ExtensionsTrie()
|
||||
# Test searching with an empty prefix returns all values
|
||||
trie.add("hello")
|
||||
trie.add("bye")
|
||||
self.assertEqual(trie.extensions(""), ["hello", "bye"])
|
||||
|
||||
def test_no_extension_match(self):
|
||||
trie = ExtensionsTrie()
|
||||
# Test searching for a prefix that doesn't match any key
|
||||
with self.assertRaises(KeyError):
|
||||
trie.extensions("unknown")
|
||||
|
||||
def test_update_value(self):
|
||||
trie = ExtensionsTrie()
|
||||
# Test updating the value of an existing key
|
||||
trie.add("hi")
|
||||
trie.add("hi")
|
||||
self.assertEqual(trie.extensions("hi"), ["hi"])
|
||||
|
||||
Reference in New Issue
Block a user