[fix] fix token healing tests and usage errors (#33931)
* auto-gptq requirement is removed & model is changed & tokenizer pad token is assigned * values func is changed with extensions & sequence key value bug is fixed * map key value check is added in ExtensionsTree * empty trimmed_ids bug is fixed * tail_id IndexError is fixed * empty trimmed_ids bug fix is updated for failed test * too much specific case for specific tokenizer is removed * input_ids check is updated * require auto-gptq import is removed * key error check is changed with empty list check * empty input_ids check is added * empty trimmed_ids fix is checked with numel function * usage change comments are added * test changes are commented * comment style and quality bugs are fixed * test comment style and quality bug is fixed
This commit is contained in:
@@ -28,7 +28,6 @@ from transformers import AutoConfig, is_torch_available, pipeline, set_seed
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
require_auto_gptq,
|
||||
require_optimum_quanto,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
@@ -3912,11 +3911,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
@parameterized.expand(
|
||||
[
|
||||
(
|
||||
"square_bracket",
|
||||
'An example ["like this"] and another example [',
|
||||
'An example ["like this"] and another example ["',
|
||||
),
|
||||
("url", 'The link is <a href="http:', 'The link is <a href="http://'),
|
||||
# aggressive_healing: "http" shouldn't be replaced with "https"
|
||||
("aggressive_healing", 'The link is <a href="http', 'The link is <a href="http'),
|
||||
@@ -3926,9 +3920,8 @@ class TokenHealingTestCase(unittest.TestCase):
|
||||
("empty_prompt", "", ""),
|
||||
]
|
||||
)
|
||||
@require_auto_gptq
|
||||
def test_prompts(self, name, input, expected):
|
||||
model_name_or_path = "TheBloke/deepseek-llm-7B-base-GPTQ"
|
||||
model_name_or_path = "distilbert/distilgpt2"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
|
||||
completion_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path,
|
||||
@@ -3937,9 +3930,16 @@ class TokenHealingTestCase(unittest.TestCase):
|
||||
revision="main",
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
"""
|
||||
tokenizer.pad_token value can be empty but it is required in the latter codes
|
||||
so assigned it here with eos_token
|
||||
"""
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
input_ids = tokenizer(input, return_tensors="pt").input_ids.to(completion_model.device)
|
||||
|
||||
healed_ids = completion_model.heal_tokens(input_ids)
|
||||
healed_ids = completion_model.heal_tokens(input_ids, tokenizer=tokenizer)
|
||||
predicted = tokenizer.decode(healed_ids[0], skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(predicted, expected)
|
||||
|
||||
Reference in New Issue
Block a user