Token healing (#30081)
* token healing impl + trie with extensions * make fixup * prefix-robust space tokenization * examples readme and requirements * make fixup * allow input prompt and model * redundant defaults * Specialized Trie * make fixup * updated tests with new inherited Tree * input ids to auto device_map * rm unused import * Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * naming convention * Revert "naming convention" This reverts commit dd39d9c5b7a969e2d8a8d2a8e54f121b82dc44f0. * naming convention * last -hopefully- changes --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -27,6 +27,7 @@ from transformers import is_torch_available, pipeline, set_seed
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
require_auto_gptq,
|
||||
require_quanto,
|
||||
require_torch,
|
||||
require_torch_multi_accelerator,
|
||||
@@ -3066,6 +3067,43 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
|
||||
self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
@parameterized.expand(
|
||||
[
|
||||
(
|
||||
"square_bracket",
|
||||
'An example ["like this"] and another example [',
|
||||
'An example ["like this"] and another example ["',
|
||||
),
|
||||
("url", 'The link is <a href="http:', 'The link is <a href="http://'),
|
||||
# aggressive_healing: "http" shouldn't be replaced with "https"
|
||||
("aggressive_healing", 'The link is <a href="http', 'The link is <a href="http'),
|
||||
("trailing_whitespace", "I read a book about ", "I read a book about"),
|
||||
("nothing_to_heal", "I read a book about", "I read a book about"),
|
||||
("single_token", "I", "I"),
|
||||
("empty_prompt", "", ""),
|
||||
]
|
||||
)
|
||||
@require_auto_gptq
|
||||
def test_prompts(self, name, input, expected):
|
||||
model_name_or_path = "TheBloke/deepseek-llm-7B-base-GPTQ"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
|
||||
completion_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path,
|
||||
device_map="auto",
|
||||
trust_remote_code=False,
|
||||
revision="main",
|
||||
use_cache=True,
|
||||
)
|
||||
input_ids = tokenizer(input, return_tensors="pt").input_ids.to(completion_model.device)
|
||||
|
||||
healed_ids = completion_model.heal_tokens(input_ids)
|
||||
predicted = tokenizer.decode(healed_ids[0], skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(predicted, expected)
|
||||
|
||||
def test_generate_from_inputs_embeds_with_bos_token_id_is_none(self):
|
||||
article = "Today a dragon flew over Paris."
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user