Token healing (#30081)

* token healing impl + trie with extensions

* make fixup

* prefix-robust space tokenization

* examples readme and requirements

* make fixup

* allow input prompt and model

* redundant defaults

* Specialized Trie

* make fixup

* updated tests with new inherited Trie

* input ids to auto device_map

* rm unused import

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* naming convention

* Revert "naming convention"

This reverts commit dd39d9c5b7a969e2d8a8d2a8e54f121b82dc44f0.

* naming convention

* last -hopefully- changes

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Ahmed Moubtahij
2024-06-03 04:53:15 -04:00
committed by GitHub
parent 5b5b48b11d
commit 39b2ff69d6
7 changed files with 324 additions and 5 deletions

View File

@@ -33,7 +33,7 @@ from transformers import (
is_tokenizers_available,
)
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_tokenizers
from transformers.tokenization_utils import Trie
from transformers.tokenization_utils import ExtensionsTrie, Trie
sys.path.append(str(Path(__file__).parent.parent / "utils"))
@@ -274,3 +274,35 @@ class TrieTest(unittest.TestCase):
trie = Trie()
parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
self.assertEqual(parts, ["AB", "C"])
class ExtensionsTrieTest(unittest.TestCase):
    """Checks the prefix lookups offered by `ExtensionsTrie.extensions`."""

    def test_extensions(self):
        # Store three words sharing the "foo" prefix plus one unrelated word.
        words = ExtensionsTrie()
        for word in ("foo", "food", "foodie", "helium"):
            words.add(word)

        # The prefix itself is a stored word, so it is part of the expected result.
        foo_matches = words.extensions("foo")
        self.assertEqual(foo_matches, ["foo", "food", "foodie"])

        # A prefix equal to a full stored word with nothing longer yields just that word.
        helium_matches = words.extensions("helium")
        self.assertEqual(helium_matches, ["helium"])

    def test_empty_prefix(self):
        # The empty prefix is expected to return every stored word;
        # the expected list is written in the order the words were added.
        words = ExtensionsTrie()
        for word in ("hello", "bye"):
            words.add(word)

        everything = words.extensions("")
        self.assertEqual(everything, ["hello", "bye"])

    def test_no_extension_match(self):
        # Nothing has been added, so the queried prefix matches no key and
        # the lookup is expected to raise KeyError.
        empty_trie = ExtensionsTrie()
        self.assertRaises(KeyError, empty_trie.extensions, "unknown")

    def test_update_value(self):
        # Adding the same key twice must not produce a duplicate entry.
        words = ExtensionsTrie()
        for _ in range(2):
            words.add("hi")

        hi_matches = words.extensions("hi")
        self.assertEqual(hi_matches, ["hi"])