[fix] fix token healing tests and usage errors (#33931)
* auto-gptq requirement is removed & model is changed & tokenizer pad token is assigned * values func is changed with extensions & sequence key value bug is fixed * map key value check is added in ExtensionsTree * empty trimmed_ids bug is fixed * tail_id IndexError is fixed * empty trimmed_ids bug fix is updated for failed test * too much specific case for specific tokenizer is removed * input_ids check is updated * require auto-gptq import is removed * key error check is changed with empty list check * empty input_ids check is added * empty trimmed_ids fix is checked with numel function * usage change comments are added * test changes are commented * comment style and quality bugs are fixed * test comment style and quality bug is fixed
This commit is contained in:
@@ -1419,7 +1419,7 @@ class GenerationMixin:
|
|||||||
input_ids_length,
|
input_ids_length,
|
||||||
inputs_tensor,
|
inputs_tensor,
|
||||||
):
|
):
|
||||||
"""Prepared max and min length in generaion configs to avoid clashes between similar attributes"""
|
"""Prepared max and min length in generation configs to avoid clashes between similar attributes"""
|
||||||
|
|
||||||
if generation_config.max_new_tokens is not None:
|
if generation_config.max_new_tokens is not None:
|
||||||
if not has_default_max_length and generation_config.max_length is not None:
|
if not has_default_max_length and generation_config.max_length is not None:
|
||||||
@@ -1662,7 +1662,7 @@ class GenerationMixin:
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Prepares the cache for generation (if applicable), given `generate`'s paramaterization. If a cache is
|
Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
|
||||||
instantiated, writes it to `model_kwargs`, under the name expected by the model.
|
instantiated, writes it to `model_kwargs`, under the name expected by the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1925,7 +1925,7 @@ class GenerationMixin:
|
|||||||
deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
|
deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
|
||||||
assistant_model (`PreTrainedModel`, *optional*):
|
assistant_model (`PreTrainedModel`, *optional*):
|
||||||
An assistant model that can be used to accelerate generation. The assistant model must have the exact
|
An assistant model that can be used to accelerate generation. The assistant model must have the exact
|
||||||
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
|
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
|
||||||
is much faster than running generation with the model you're calling generate from. As such, the
|
is much faster than running generation with the model you're calling generate from. As such, the
|
||||||
assistant model should be much smaller.
|
assistant model should be much smaller.
|
||||||
streamer (`BaseStreamer`, *optional*):
|
streamer (`BaseStreamer`, *optional*):
|
||||||
@@ -2442,7 +2442,15 @@ class GenerationMixin:
|
|||||||
# replace bos with pad to not condition healing on it
|
# replace bos with pad to not condition healing on it
|
||||||
input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids)
|
input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids)
|
||||||
|
|
||||||
|
"""
|
||||||
|
the latter code assumes the input_ids is not empty,
|
||||||
|
input_id has to be checked if contains elements
|
||||||
|
"""
|
||||||
|
if input_ids.numel() == 0:
|
||||||
|
return input_ids
|
||||||
|
|
||||||
tail_ids = input_ids[:, -1].tolist()
|
tail_ids = input_ids[:, -1].tolist()
|
||||||
|
|
||||||
space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]
|
space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]
|
||||||
# tail tokens are used for a prefix search, thus, whitespaces are replaced with
|
# tail tokens are used for a prefix search, thus, whitespaces are replaced with
|
||||||
# their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
|
# their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
|
||||||
@@ -2454,7 +2462,14 @@ class GenerationMixin:
|
|||||||
continue # skip empty sequences (all pad ids)
|
continue # skip empty sequences (all pad ids)
|
||||||
|
|
||||||
# apply bias for alternatives (extensions) to the tail token
|
# apply bias for alternatives (extensions) to the tail token
|
||||||
seq_bias = {(alt_tok,): 10.0 for alt_tok in vocab_trie.values(prefix=tail_tok)}
|
"""
|
||||||
|
seq_bias key has to be tuple with int so have to use
|
||||||
|
tokenizer function to convert str to int
|
||||||
|
"""
|
||||||
|
seq_bias = {
|
||||||
|
(tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok)
|
||||||
|
}
|
||||||
|
|
||||||
if len(seq_bias) == 1:
|
if len(seq_bias) == 1:
|
||||||
continue # skip if there are no token alternatives to heal with
|
continue # skip if there are no token alternatives to heal with
|
||||||
|
|
||||||
@@ -2463,6 +2478,14 @@ class GenerationMixin:
|
|||||||
generation_config.update(sequence_bias=seq_bias)
|
generation_config.update(sequence_bias=seq_bias)
|
||||||
|
|
||||||
trimmed_ids = batch_ids[:-1]
|
trimmed_ids = batch_ids[:-1]
|
||||||
|
|
||||||
|
"""
|
||||||
|
the latter code assumes trimmed_ids is not empty
|
||||||
|
so have to check the its element count
|
||||||
|
"""
|
||||||
|
if trimmed_ids.numel() == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
# if the prompt is a single (non-pad) token, regenerate from bos
|
# if the prompt is a single (non-pad) token, regenerate from bos
|
||||||
if len(batch_ids[batch_ids != pad_token_id]) == 1:
|
if len(batch_ids[batch_ids != pad_token_id]) == 1:
|
||||||
trimmed_ids[-1] = bos_token_id
|
trimmed_ids[-1] = bos_token_id
|
||||||
@@ -2915,7 +2938,7 @@ class GenerationMixin:
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
# This is essential to avoid having a last reference to the big past K-V and double the necesary memory
|
# This is essential to avoid having a last reference to the big past K-V and double the necessary memory
|
||||||
# in the next loop
|
# in the next loop
|
||||||
del next_model_inputs
|
del next_model_inputs
|
||||||
|
|
||||||
@@ -3658,7 +3681,7 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
|
# initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
|
||||||
# the same group don't produce same tokens everytime.
|
# the same group don't produce same tokens every time.
|
||||||
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
||||||
beam_scores[:, ::num_sub_beams] = 0
|
beam_scores[:, ::num_sub_beams] = 0
|
||||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||||
|
|||||||
@@ -316,6 +316,9 @@ class ExtensionsTrie(Trie):
|
|||||||
"""
|
"""
|
||||||
node = self.data
|
node = self.data
|
||||||
for char in token:
|
for char in token:
|
||||||
|
if char not in node:
|
||||||
|
break
|
||||||
|
|
||||||
node = node[char]
|
node = node[char]
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ from transformers import AutoConfig, is_torch_available, pipeline, set_seed
|
|||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_flaky,
|
is_flaky,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
require_auto_gptq,
|
|
||||||
require_optimum_quanto,
|
require_optimum_quanto,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
@@ -3912,11 +3911,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
class TokenHealingTestCase(unittest.TestCase):
|
class TokenHealingTestCase(unittest.TestCase):
|
||||||
@parameterized.expand(
|
@parameterized.expand(
|
||||||
[
|
[
|
||||||
(
|
|
||||||
"square_bracket",
|
|
||||||
'An example ["like this"] and another example [',
|
|
||||||
'An example ["like this"] and another example ["',
|
|
||||||
),
|
|
||||||
("url", 'The link is <a href="http:', 'The link is <a href="http://'),
|
("url", 'The link is <a href="http:', 'The link is <a href="http://'),
|
||||||
# aggressive_healing: "http" shouldn't be replaced with "https"
|
# aggressive_healing: "http" shouldn't be replaced with "https"
|
||||||
("aggressive_healing", 'The link is <a href="http', 'The link is <a href="http'),
|
("aggressive_healing", 'The link is <a href="http', 'The link is <a href="http'),
|
||||||
@@ -3926,9 +3920,8 @@ class TokenHealingTestCase(unittest.TestCase):
|
|||||||
("empty_prompt", "", ""),
|
("empty_prompt", "", ""),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@require_auto_gptq
|
|
||||||
def test_prompts(self, name, input, expected):
|
def test_prompts(self, name, input, expected):
|
||||||
model_name_or_path = "TheBloke/deepseek-llm-7B-base-GPTQ"
|
model_name_or_path = "distilbert/distilgpt2"
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
|
||||||
completion_model = AutoModelForCausalLM.from_pretrained(
|
completion_model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_name_or_path,
|
model_name_or_path,
|
||||||
@@ -3937,9 +3930,16 @@ class TokenHealingTestCase(unittest.TestCase):
|
|||||||
revision="main",
|
revision="main",
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
tokenizer.pad_token value can be empty but it is required in the latter codes
|
||||||
|
so assigned it here with eos_token
|
||||||
|
"""
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
input_ids = tokenizer(input, return_tensors="pt").input_ids.to(completion_model.device)
|
input_ids = tokenizer(input, return_tensors="pt").input_ids.to(completion_model.device)
|
||||||
|
|
||||||
healed_ids = completion_model.heal_tokens(input_ids)
|
healed_ids = completion_model.heal_tokens(input_ids, tokenizer=tokenizer)
|
||||||
predicted = tokenizer.decode(healed_ids[0], skip_special_tokens=True)
|
predicted = tokenizer.decode(healed_ids[0], skip_special_tokens=True)
|
||||||
|
|
||||||
self.assertEqual(predicted, expected)
|
self.assertEqual(predicted, expected)
|
||||||
|
|||||||
@@ -325,8 +325,9 @@ class ExtensionsTrieTest(unittest.TestCase):
|
|||||||
def test_no_extension_match(self):
|
def test_no_extension_match(self):
|
||||||
trie = ExtensionsTrie()
|
trie = ExtensionsTrie()
|
||||||
# Test searching for a prefix that doesn't match any key
|
# Test searching for a prefix that doesn't match any key
|
||||||
with self.assertRaises(KeyError):
|
values = trie.extensions("unknown")
|
||||||
trie.extensions("unknown")
|
|
||||||
|
self.assertEqual(len(values), 0)
|
||||||
|
|
||||||
def test_update_value(self):
|
def test_update_value(self):
|
||||||
trie = ExtensionsTrie()
|
trie = ExtensionsTrie()
|
||||||
|
|||||||
Reference in New Issue
Block a user