Token healing (#30081)
* token healing impl + trie with extensions
* make fixup
* prefix-robust space tokenization
* examples readme and requirements
* make fixup
* allow input prompt and model
* redundant defaults
* Specialized Trie
* make fixup
* updated tests with new inherited Tree
* input ids to auto device_map
* rm unused import
* Update src/transformers/generation/utils.py

  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* naming convention
* Revert "naming convention"

  This reverts commit dd39d9c5b7a969e2d8a8d2a8e54f121b82dc44f0.
* naming convention
* last -hopefully- changes

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -27,6 +27,7 @@ from transformers import is_torch_available, pipeline, set_seed
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
require_auto_gptq,
|
||||
require_quanto,
|
||||
require_torch,
|
||||
require_torch_multi_accelerator,
|
||||
@@ -3066,6 +3067,43 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
|
||||
self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
|
||||
|
||||
|
||||
@require_torch
class TokenHealingTestCase(unittest.TestCase):
    """Checks the output of `heal_tokens` on a set of prompts, compared as decoded text."""

    @parameterized.expand(
        [
            (
                "square_bracket",
                'An example ["like this"] and another example [',
                'An example ["like this"] and another example ["',
            ),
            ("url", 'The link is <a href="http:', 'The link is <a href="http://'),
            # aggressive_healing: "http" shouldn't be replaced with "https"
            ("aggressive_healing", 'The link is <a href="http', 'The link is <a href="http'),
            ("trailing_whitespace", "I read a book about ", "I read a book about"),
            ("nothing_to_heal", "I read a book about", "I read a book about"),
            ("single_token", "I", "I"),
            ("empty_prompt", "", ""),
        ]
    )
    @require_auto_gptq
    def test_prompts(self, name, input, expected):
        """Heal the tokenized `input` prompt and compare its decoded form with `expected`.

        Args:
            name: Label of the parameterized case; not used in the body.
            input: Prompt text passed to the tokenizer.
            expected: Text the healed ids must decode to (special tokens skipped).
        """
        checkpoint = "TheBloke/deepseek-llm-7B-base-GPTQ"
        load_kwargs = {
            "device_map": "auto",
            "trust_remote_code": False,
            "revision": "main",
            "use_cache": True,
        }

        tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True)
        completion_model = AutoModelForCausalLM.from_pretrained(checkpoint, **load_kwargs)

        # Move the prompt ids onto whatever device the model reports.
        encoded = tokenizer(input, return_tensors="pt")
        prompt_ids = encoded.input_ids.to(completion_model.device)

        healed = completion_model.heal_tokens(prompt_ids)
        # Only the first row of the healed ids is decoded and compared.
        decoded = tokenizer.decode(healed[0], skip_special_tokens=True)

        self.assertEqual(decoded, expected)
|
||||
|
||||
def test_generate_from_inputs_embeds_with_bos_token_id_is_none(self):
|
||||
article = "Today a dragon flew over Paris."
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
|
||||
@@ -33,7 +33,7 @@ from transformers import (
|
||||
is_tokenizers_available,
|
||||
)
|
||||
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_tokenizers
|
||||
from transformers.tokenization_utils import Trie
|
||||
from transformers.tokenization_utils import ExtensionsTrie, Trie
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||
@@ -274,3 +274,35 @@ class TrieTest(unittest.TestCase):
|
||||
trie = Trie()
|
||||
parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
|
||||
self.assertEqual(parts, ["AB", "C"])
|
||||
|
||||
|
||||
class ExtensionsTrieTest(unittest.TestCase):
    """Unit tests for prefix lookups through `ExtensionsTrie.extensions`."""

    def test_extensions(self):
        """A prefix lookup returns the stored words that start with that prefix."""
        vocab = ExtensionsTrie()
        for word in ("foo", "food", "foodie", "helium"):
            vocab.add(word)

        matches_for_foo = vocab.extensions("foo")
        matches_for_helium = vocab.extensions("helium")

        self.assertEqual(matches_for_foo, ["foo", "food", "foodie"])
        self.assertEqual(matches_for_helium, ["helium"])

    def test_empty_prefix(self):
        """An empty prefix returns every stored word."""
        vocab = ExtensionsTrie()
        for word in ("hello", "bye"):
            vocab.add(word)

        self.assertEqual(vocab.extensions(""), ["hello", "bye"])

    def test_no_extension_match(self):
        """A prefix that matches no stored word raises `KeyError`."""
        vocab = ExtensionsTrie()

        self.assertRaises(KeyError, vocab.extensions, "unknown")

    def test_update_value(self):
        """Adding the same word twice still yields a single entry for it."""
        vocab = ExtensionsTrie()
        for _ in range(2):
            vocab.add("hi")

        self.assertEqual(vocab.extensions("hi"), ["hi"])
|
||||
|
||||
Reference in New Issue
Block a user