[fix] fix token healing tests and usage errors (#33931)

* auto-gptq requirement is removed & model is changed & tokenizer pad token is assigned

* values func is changed with extensions & sequence key value bug is fixed

* map key value check is added in ExtensionsTree

* empty trimmed_ids bug is fixed

* tail_id IndexError is fixed

* empty trimmed_ids bug fix is updated for failed test

* too much specific case for specific tokenizer is removed

* input_ids check is updated

* require auto-gptq import is removed

* key error check is changed with empty list check

* empty input_ids check is added

* empty trimmed_ids fix is checked with numel function

* usage change comments are added

* test changes are commented

* comment style and quality bugs are fixed

* test comment style and quality bug is fixed
This commit is contained in:
alpertunga-bile
2024-10-16 15:22:55 +03:00
committed by GitHub
parent 9ba021ea75
commit 98bad9c6d6
4 changed files with 44 additions and 17 deletions

View File

@@ -28,7 +28,6 @@ from transformers import AutoConfig, is_torch_available, pipeline, set_seed
from transformers.testing_utils import (
is_flaky,
require_accelerate,
require_auto_gptq,
require_optimum_quanto,
require_torch,
require_torch_gpu,
@@ -3912,11 +3911,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
class TokenHealingTestCase(unittest.TestCase):
@parameterized.expand(
[
(
"square_bracket",
'An example ["like this"] and another example [',
'An example ["like this"] and another example ["',
),
("url", 'The link is <a href="http:', 'The link is <a href="http://'),
# aggressive_healing: "http" shouldn't be replaced with "https"
("aggressive_healing", 'The link is <a href="http', 'The link is <a href="http'),
@@ -3926,9 +3920,8 @@ class TokenHealingTestCase(unittest.TestCase):
("empty_prompt", "", ""),
]
)
@require_auto_gptq
def test_prompts(self, name, input, expected):
model_name_or_path = "TheBloke/deepseek-llm-7B-base-GPTQ"
model_name_or_path = "distilbert/distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
completion_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
@@ -3937,9 +3930,16 @@ class TokenHealingTestCase(unittest.TestCase):
revision="main",
use_cache=True,
)
"""
tokenizer.pad_token value can be empty but it is required in the latter codes
so assigned it here with eos_token
"""
tokenizer.pad_token = tokenizer.eos_token
input_ids = tokenizer(input, return_tensors="pt").input_ids.to(completion_model.device)
healed_ids = completion_model.heal_tokens(input_ids)
healed_ids = completion_model.heal_tokens(input_ids, tokenizer=tokenizer)
predicted = tokenizer.decode(healed_ids[0], skip_special_tokens=True)
self.assertEqual(predicted, expected)