[fix] fix token healing tests and usage errors (#33931)
* auto-gptq requirement is removed, model is changed, and tokenizer pad token is assigned * `values` call is replaced with `extensions` and the sequence bias key/value bug is fixed * key check is added in ExtensionsTrie * empty trimmed_ids bug is fixed * tail_id IndexError is fixed * empty trimmed_ids bug fix is updated for the failing test * overly specific case for a specific tokenizer is removed * input_ids check is updated * require_auto_gptq import is removed * KeyError check is replaced with an empty-list check * empty input_ids check is added * empty trimmed_ids fix now uses the numel function * comments explaining the usage changes are added * test changes are commented * comment style and quality issues are fixed * test comment style and quality issue is fixed
This commit is contained in:
@@ -1419,7 +1419,7 @@ class GenerationMixin:
|
||||
input_ids_length,
|
||||
inputs_tensor,
|
||||
):
|
||||
"""Prepared max and min length in generaion configs to avoid clashes between similar attributes"""
|
||||
"""Prepared max and min length in generation configs to avoid clashes between similar attributes"""
|
||||
|
||||
if generation_config.max_new_tokens is not None:
|
||||
if not has_default_max_length and generation_config.max_length is not None:
|
||||
@@ -1662,7 +1662,7 @@ class GenerationMixin:
|
||||
device: torch.device,
|
||||
) -> bool:
|
||||
"""
|
||||
Prepares the cache for generation (if applicable), given `generate`'s paramaterization. If a cache is
|
||||
Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
|
||||
instantiated, writes it to `model_kwargs`, under the name expected by the model.
|
||||
"""
|
||||
|
||||
@@ -1925,7 +1925,7 @@ class GenerationMixin:
|
||||
deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
|
||||
assistant_model (`PreTrainedModel`, *optional*):
|
||||
An assistant model that can be used to accelerate generation. The assistant model must have the exact
|
||||
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
|
||||
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
|
||||
is much faster than running generation with the model you're calling generate from. As such, the
|
||||
assistant model should be much smaller.
|
||||
streamer (`BaseStreamer`, *optional*):
|
||||
@@ -2442,7 +2442,15 @@ class GenerationMixin:
|
||||
# replace bos with pad to not condition healing on it
|
||||
input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids)
|
||||
|
||||
"""
|
||||
The code below assumes input_ids is not empty,
|
||||
so check that input_ids contains elements first.
|
||||
"""
|
||||
if input_ids.numel() == 0:
|
||||
return input_ids
|
||||
|
||||
tail_ids = input_ids[:, -1].tolist()
|
||||
|
||||
space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]
|
||||
# tail tokens are used for a prefix search, thus, whitespaces are replaced with
|
||||
# their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
|
||||
@@ -2454,7 +2462,14 @@ class GenerationMixin:
|
||||
continue # skip empty sequences (all pad ids)
|
||||
|
||||
# apply bias for alternatives (extensions) to the tail token
|
||||
seq_bias = {(alt_tok,): 10.0 for alt_tok in vocab_trie.values(prefix=tail_tok)}
|
||||
"""
|
||||
seq_bias key has to be tuple with int so have to use
|
||||
tokenizer function to convert str to int
|
||||
"""
|
||||
seq_bias = {
|
||||
(tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok)
|
||||
}
|
||||
|
||||
if len(seq_bias) == 1:
|
||||
continue # skip if there are no token alternatives to heal with
|
||||
|
||||
@@ -2463,6 +2478,14 @@ class GenerationMixin:
|
||||
generation_config.update(sequence_bias=seq_bias)
|
||||
|
||||
trimmed_ids = batch_ids[:-1]
|
||||
|
||||
"""
|
||||
The code below assumes trimmed_ids is not empty,
|
||||
so check its element count first.
|
||||
"""
|
||||
if trimmed_ids.numel() == 0:
|
||||
continue
|
||||
|
||||
# if the prompt is a single (non-pad) token, regenerate from bos
|
||||
if len(batch_ids[batch_ids != pad_token_id]) == 1:
|
||||
trimmed_ids[-1] = bos_token_id
|
||||
@@ -2915,7 +2938,7 @@ class GenerationMixin:
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
# This is essential to avoid having a last reference to the big past K-V and double the necesary memory
|
||||
# This is essential to avoid having a last reference to the big past K-V and double the necessary memory
|
||||
# in the next loop
|
||||
del next_model_inputs
|
||||
|
||||
@@ -3658,7 +3681,7 @@ class GenerationMixin:
|
||||
)
|
||||
|
||||
# initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
|
||||
# the same group don't produce same tokens everytime.
|
||||
# the same group don't produce same tokens every time.
|
||||
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
||||
beam_scores[:, ::num_sub_beams] = 0
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
@@ -316,6 +316,9 @@ class ExtensionsTrie(Trie):
|
||||
"""
|
||||
node = self.data
|
||||
for char in token:
|
||||
if char not in node:
|
||||
break
|
||||
|
||||
node = node[char]
|
||||
return node
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ from transformers import AutoConfig, is_torch_available, pipeline, set_seed
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
require_auto_gptq,
|
||||
require_optimum_quanto,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
@@ -3912,11 +3911,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
@parameterized.expand(
|
||||
[
|
||||
(
|
||||
"square_bracket",
|
||||
'An example ["like this"] and another example [',
|
||||
'An example ["like this"] and another example ["',
|
||||
),
|
||||
("url", 'The link is <a href="http:', 'The link is <a href="http://'),
|
||||
# aggressive_healing: "http" shouldn't be replaced with "https"
|
||||
("aggressive_healing", 'The link is <a href="http', 'The link is <a href="http'),
|
||||
@@ -3926,9 +3920,8 @@ class TokenHealingTestCase(unittest.TestCase):
|
||||
("empty_prompt", "", ""),
|
||||
]
|
||||
)
|
||||
@require_auto_gptq
|
||||
def test_prompts(self, name, input, expected):
|
||||
model_name_or_path = "TheBloke/deepseek-llm-7B-base-GPTQ"
|
||||
model_name_or_path = "distilbert/distilgpt2"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
|
||||
completion_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path,
|
||||
@@ -3937,9 +3930,16 @@ class TokenHealingTestCase(unittest.TestCase):
|
||||
revision="main",
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
"""
|
||||
tokenizer.pad_token may be empty, but the code below requires it,
|
||||
so assign eos_token to it here.
|
||||
"""
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
input_ids = tokenizer(input, return_tensors="pt").input_ids.to(completion_model.device)
|
||||
|
||||
healed_ids = completion_model.heal_tokens(input_ids)
|
||||
healed_ids = completion_model.heal_tokens(input_ids, tokenizer=tokenizer)
|
||||
predicted = tokenizer.decode(healed_ids[0], skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(predicted, expected)
|
||||
|
||||
@@ -325,8 +325,9 @@ class ExtensionsTrieTest(unittest.TestCase):
|
||||
def test_no_extension_match(self):
|
||||
trie = ExtensionsTrie()
|
||||
# Test searching for a prefix that doesn't match any key
|
||||
with self.assertRaises(KeyError):
|
||||
trie.extensions("unknown")
|
||||
values = trie.extensions("unknown")
|
||||
|
||||
self.assertEqual(len(values), 0)
|
||||
|
||||
def test_update_value(self):
|
||||
trie = ExtensionsTrie()
|
||||
|
||||
Reference in New Issue
Block a user