[fix] fix token healing tests and usage errors (#33931)

* auto-gptq requirement is removed & model is changed & tokenizer pad token is assigned

* values func is changed with extensions & sequence key value bug is fixed

* map key value check is added in ExtensionsTree

* empty trimmed_ids bug is fixed

* tail_id IndexError is fixed

* empty trimmed_ids bug fix is updated for failed test

* too much specific case for specific tokenizer is removed

* input_ids check is updated

* require auto-gptq import is removed

* key error check is changed with empty list check

* empty input_ids check is added

* empty trimmed_ids fix is checked with numel function

* usage change comments are added

* test changes are commented

* comment style and quality bugs are fixed

* test comment style and quality bug is fixed
This commit is contained in:
alpertunga-bile
2024-10-16 15:22:55 +03:00
committed by GitHub
parent 9ba021ea75
commit 98bad9c6d6
4 changed files with 44 additions and 17 deletions

View File

@@ -1419,7 +1419,7 @@ class GenerationMixin:
input_ids_length, input_ids_length,
inputs_tensor, inputs_tensor,
): ):
"""Prepared max and min length in generaion configs to avoid clashes between similar attributes""" """Prepared max and min length in generation configs to avoid clashes between similar attributes"""
if generation_config.max_new_tokens is not None: if generation_config.max_new_tokens is not None:
if not has_default_max_length and generation_config.max_length is not None: if not has_default_max_length and generation_config.max_length is not None:
@@ -1662,7 +1662,7 @@ class GenerationMixin:
device: torch.device, device: torch.device,
) -> bool: ) -> bool:
""" """
Prepares the cache for generation (if applicable), given `generate`'s paramaterization. If a cache is Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
instantiated, writes it to `model_kwargs`, under the name expected by the model. instantiated, writes it to `model_kwargs`, under the name expected by the model.
""" """
@@ -1925,7 +1925,7 @@ class GenerationMixin:
deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`. deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
assistant_model (`PreTrainedModel`, *optional*): assistant_model (`PreTrainedModel`, *optional*):
An assistant model that can be used to accelerate generation. The assistant model must have the exact An assistant model that can be used to accelerate generation. The assistant model must have the exact
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
is much faster than running generation with the model you're calling generate from. As such, the is much faster than running generation with the model you're calling generate from. As such, the
assistant model should be much smaller. assistant model should be much smaller.
streamer (`BaseStreamer`, *optional*): streamer (`BaseStreamer`, *optional*):
@@ -2442,7 +2442,15 @@ class GenerationMixin:
# replace bos with pad to not condition healing on it # replace bos with pad to not condition healing on it
input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids) input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids)
"""
the latter code assumes the input_ids is not empty,
input_id has to be checked if contains elements
"""
if input_ids.numel() == 0:
return input_ids
tail_ids = input_ids[:, -1].tolist() tail_ids = input_ids[:, -1].tolist()
space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0] space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]
# tail tokens are used for a prefix search, thus, whitespaces are replaced with # tail tokens are used for a prefix search, thus, whitespaces are replaced with
# their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace # their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
@@ -2454,7 +2462,14 @@ class GenerationMixin:
continue # skip empty sequences (all pad ids) continue # skip empty sequences (all pad ids)
# apply bias for alternatives (extensions) to the tail token # apply bias for alternatives (extensions) to the tail token
seq_bias = {(alt_tok,): 10.0 for alt_tok in vocab_trie.values(prefix=tail_tok)} """
seq_bias key has to be tuple with int so have to use
tokenizer function to convert str to int
"""
seq_bias = {
(tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok)
}
if len(seq_bias) == 1: if len(seq_bias) == 1:
continue # skip if there are no token alternatives to heal with continue # skip if there are no token alternatives to heal with
@@ -2463,6 +2478,14 @@ class GenerationMixin:
generation_config.update(sequence_bias=seq_bias) generation_config.update(sequence_bias=seq_bias)
trimmed_ids = batch_ids[:-1] trimmed_ids = batch_ids[:-1]
"""
the latter code assumes trimmed_ids is not empty
so have to check the its element count
"""
if trimmed_ids.numel() == 0:
continue
# if the prompt is a single (non-pad) token, regenerate from bos # if the prompt is a single (non-pad) token, regenerate from bos
if len(batch_ids[batch_ids != pad_token_id]) == 1: if len(batch_ids[batch_ids != pad_token_id]) == 1:
trimmed_ids[-1] = bos_token_id trimmed_ids[-1] = bos_token_id
@@ -2915,7 +2938,7 @@ class GenerationMixin:
output_attentions=output_attentions, output_attentions=output_attentions,
) )
# This is essential to avoid having a last reference to the big past K-V and double the necesary memory # This is essential to avoid having a last reference to the big past K-V and double the necessary memory
# in the next loop # in the next loop
del next_model_inputs del next_model_inputs
@@ -3658,7 +3681,7 @@ class GenerationMixin:
) )
# initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
# the same group don't produce same tokens everytime. # the same group don't produce same tokens every time.
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
beam_scores[:, ::num_sub_beams] = 0 beam_scores[:, ::num_sub_beams] = 0
beam_scores = beam_scores.view((batch_size * num_beams,)) beam_scores = beam_scores.view((batch_size * num_beams,))

View File

@@ -316,6 +316,9 @@ class ExtensionsTrie(Trie):
""" """
node = self.data node = self.data
for char in token: for char in token:
if char not in node:
break
node = node[char] node = node[char]
return node return node

View File

@@ -28,7 +28,6 @@ from transformers import AutoConfig, is_torch_available, pipeline, set_seed
from transformers.testing_utils import ( from transformers.testing_utils import (
is_flaky, is_flaky,
require_accelerate, require_accelerate,
require_auto_gptq,
require_optimum_quanto, require_optimum_quanto,
require_torch, require_torch,
require_torch_gpu, require_torch_gpu,
@@ -3912,11 +3911,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
class TokenHealingTestCase(unittest.TestCase): class TokenHealingTestCase(unittest.TestCase):
@parameterized.expand( @parameterized.expand(
[ [
(
"square_bracket",
'An example ["like this"] and another example [',
'An example ["like this"] and another example ["',
),
("url", 'The link is <a href="http:', 'The link is <a href="http://'), ("url", 'The link is <a href="http:', 'The link is <a href="http://'),
# aggressive_healing: "http" shouldn't be replaced with "https" # aggressive_healing: "http" shouldn't be replaced with "https"
("aggressive_healing", 'The link is <a href="http', 'The link is <a href="http'), ("aggressive_healing", 'The link is <a href="http', 'The link is <a href="http'),
@@ -3926,9 +3920,8 @@ class TokenHealingTestCase(unittest.TestCase):
("empty_prompt", "", ""), ("empty_prompt", "", ""),
] ]
) )
@require_auto_gptq
def test_prompts(self, name, input, expected): def test_prompts(self, name, input, expected):
model_name_or_path = "TheBloke/deepseek-llm-7B-base-GPTQ" model_name_or_path = "distilbert/distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True) tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
completion_model = AutoModelForCausalLM.from_pretrained( completion_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path, model_name_or_path,
@@ -3937,9 +3930,16 @@ class TokenHealingTestCase(unittest.TestCase):
revision="main", revision="main",
use_cache=True, use_cache=True,
) )
"""
tokenizer.pad_token value can be empty but it is required in the latter codes
so assigned it here with eos_token
"""
tokenizer.pad_token = tokenizer.eos_token
input_ids = tokenizer(input, return_tensors="pt").input_ids.to(completion_model.device) input_ids = tokenizer(input, return_tensors="pt").input_ids.to(completion_model.device)
healed_ids = completion_model.heal_tokens(input_ids) healed_ids = completion_model.heal_tokens(input_ids, tokenizer=tokenizer)
predicted = tokenizer.decode(healed_ids[0], skip_special_tokens=True) predicted = tokenizer.decode(healed_ids[0], skip_special_tokens=True)
self.assertEqual(predicted, expected) self.assertEqual(predicted, expected)

View File

@@ -325,8 +325,9 @@ class ExtensionsTrieTest(unittest.TestCase):
def test_no_extension_match(self): def test_no_extension_match(self):
trie = ExtensionsTrie() trie = ExtensionsTrie()
# Test searching for a prefix that doesn't match any key # Test searching for a prefix that doesn't match any key
with self.assertRaises(KeyError): values = trie.extensions("unknown")
trie.extensions("unknown")
self.assertEqual(len(values), 0)
def test_update_value(self): def test_update_value(self):
trie = ExtensionsTrie() trie = ExtensionsTrie()