Token healing (#30081)

* token healing impl + trie with extensions

* make fixup

* prefix-robust space tokenization

* examples readme and requirements

* make fixup

* allow input prompt and model

* redundant defaults

* Specialized Trie

* make fixup

* updated tests with new inherited Tree

* input ids to auto device_map

* rm unused import

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* naming convention

* Revert "naming convention"

This reverts commit dd39d9c5b7a969e2d8a8d2a8e54f121b82dc44f0.

* naming convention

* last -hopefully- changes

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Ahmed Moubtahij
2024-06-03 04:53:15 -04:00
committed by GitHub
parent 5b5b48b11d
commit 39b2ff69d6
7 changed files with 324 additions and 5 deletions

View File

@@ -27,6 +27,7 @@ from transformers import is_torch_available, pipeline, set_seed
from transformers.testing_utils import (
is_flaky,
require_accelerate,
require_auto_gptq,
require_quanto,
require_torch,
require_torch_multi_accelerator,
@@ -3066,6 +3067,43 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
@require_torch
class TokenHealingTestCase(unittest.TestCase):
@parameterized.expand(
[
(
"square_bracket",
'An example ["like this"] and another example [',
'An example ["like this"] and another example ["',
),
("url", 'The link is <a href="http:', 'The link is <a href="http://'),
# aggressive_healing: "http" shouldn't be replaced with "https"
("aggressive_healing", 'The link is <a href="http', 'The link is <a href="http'),
("trailing_whitespace", "I read a book about ", "I read a book about"),
("nothing_to_heal", "I read a book about", "I read a book about"),
("single_token", "I", "I"),
("empty_prompt", "", ""),
]
)
@require_auto_gptq
def test_prompts(self, name, input, expected):
model_name_or_path = "TheBloke/deepseek-llm-7B-base-GPTQ"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
completion_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
device_map="auto",
trust_remote_code=False,
revision="main",
use_cache=True,
)
input_ids = tokenizer(input, return_tensors="pt").input_ids.to(completion_model.device)
healed_ids = completion_model.heal_tokens(input_ids)
predicted = tokenizer.decode(healed_ids[0], skip_special_tokens=True)
self.assertEqual(predicted, expected)
def test_generate_from_inputs_embeds_with_bos_token_id_is_none(self):
article = "Today a dragon flew over Paris."
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)

View File

@@ -33,7 +33,7 @@ from transformers import (
is_tokenizers_available,
)
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_tokenizers
from transformers.tokenization_utils import Trie
from transformers.tokenization_utils import ExtensionsTrie, Trie
sys.path.append(str(Path(__file__).parent.parent / "utils"))
@@ -274,3 +274,35 @@ class TrieTest(unittest.TestCase):
trie = Trie()
parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
self.assertEqual(parts, ["AB", "C"])
class ExtensionsTrieTest(unittest.TestCase):
def test_extensions(self):
# Test searching by prefix
trie = ExtensionsTrie()
trie.add("foo")
trie.add("food")
trie.add("foodie")
trie.add("helium")
self.assertEqual(trie.extensions("foo"), ["foo", "food", "foodie"])
self.assertEqual(trie.extensions("helium"), ["helium"])
def test_empty_prefix(self):
trie = ExtensionsTrie()
# Test searching with an empty prefix returns all values
trie.add("hello")
trie.add("bye")
self.assertEqual(trie.extensions(""), ["hello", "bye"])
def test_no_extension_match(self):
trie = ExtensionsTrie()
# Test searching for a prefix that doesn't match any key
with self.assertRaises(KeyError):
trie.extensions("unknown")
def test_update_value(self):
trie = ExtensionsTrie()
# Test updating the value of an existing key
trie.add("hi")
trie.add("hi")
self.assertEqual(trie.extensions("hi"), ["hi"])