From 98bad9c6d6c9ac98b42164fd882f94d4b5bfa4d7 Mon Sep 17 00:00:00 2001 From: alpertunga-bile Date: Wed, 16 Oct 2024 15:22:55 +0300 Subject: [PATCH] [fix] fix token healing tests and usage errors (#33931) * auto-gptq requirement is removed & model is changed & tokenizer pad token is assigned * values func is changed with extensions & sequence key value bug is fixed * map key value check is added in ExtensionsTree * empty trimmed_ids bug is fixed * tail_id IndexError is fixed * empty trimmed_ids bug fix is updated for failed test * too much specific case for specific tokenizer is removed * input_ids check is updated * require auto-gptq import is removed * key error check is changed with empty list check * empty input_ids check is added * empty trimmed_ids fix is checked with numel function * usage change comments are added * test changes are commented * comment style and quality bugs are fixed * test comment style and quality bug is fixed --- src/transformers/generation/utils.py | 35 +++++++++++++++++++++----- src/transformers/tokenization_utils.py | 3 +++ tests/generation/test_utils.py | 18 ++++++------- tests/utils/test_tokenization_utils.py | 5 ++-- 4 files changed, 44 insertions(+), 17 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6d71b754d6..86ea702dd9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1419,7 +1419,7 @@ class GenerationMixin: input_ids_length, inputs_tensor, ): - """Prepared max and min length in generaion configs to avoid clashes between similar attributes""" + """Prepared max and min length in generation configs to avoid clashes between similar attributes""" if generation_config.max_new_tokens is not None: if not has_default_max_length and generation_config.max_length is not None: @@ -1662,7 +1662,7 @@ class GenerationMixin: device: torch.device, ) -> bool: """ - Prepares the cache for generation (if applicable), given `generate`'s paramaterization. If a cache is + Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is instantiated, writes it to `model_kwargs`, under the name expected by the model. """ @@ -1925,7 +1925,7 @@ class GenerationMixin: deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`. assistant_model (`PreTrainedModel`, *optional*): An assistant model that can be used to accelerate generation. The assistant model must have the exact - same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model + same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model is much faster than running generation with the model you're calling generate from. As such, the assistant model should be much smaller. streamer (`BaseStreamer`, *optional*): @@ -2442,7 +2442,15 @@ class GenerationMixin: # replace bos with pad to not condition healing on it input_ids = torch.where(input_ids == bos_token_id, pad_token_id, input_ids) + """ + the latter code assumes the input_ids is not empty, + input_id has to be checked if contains elements + """ + if input_ids.numel() == 0: + return input_ids + tail_ids = input_ids[:, -1].tolist() + space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0] # tail tokens are used for a prefix search, thus, whitespaces are replaced with # their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace @@ -2454,7 +2462,14 @@ class GenerationMixin: continue # skip empty sequences (all pad ids) # apply bias for alternatives (extensions) to the tail token - seq_bias = {(alt_tok,): 10.0 for alt_tok in vocab_trie.values(prefix=tail_tok)} + """ + seq_bias key has to be tuple with int so have to use + tokenizer function to convert str to int + """ + seq_bias = { + (tokenizer.convert_tokens_to_ids(alt_tok),): 10.0 for alt_tok in vocab_trie.extensions(prefix=tail_tok) + } + if len(seq_bias) == 1: continue # skip if there are no token alternatives to heal with @@ -2463,6 +2478,14 @@ class GenerationMixin: generation_config.update(sequence_bias=seq_bias) trimmed_ids = batch_ids[:-1] + + """ + the latter code assumes trimmed_ids is not empty + so have to check the its element count + """ + if trimmed_ids.numel() == 0: + continue + # if the prompt is a single (non-pad) token, regenerate from bos if len(batch_ids[batch_ids != pad_token_id]) == 1: trimmed_ids[-1] = bos_token_id @@ -2915,7 +2938,7 @@ class GenerationMixin: output_attentions=output_attentions, ) - # This is essential to avoid having a last reference to the big past K-V and double the necesary memory + # This is essential to avoid having a last reference to the big past K-V and double the necessary memory # in the next loop del next_model_inputs @@ -3658,7 +3681,7 @@ class GenerationMixin: ) # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in - # the same group don't produce same tokens everytime. + # the same group don't produce same tokens every time. beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) beam_scores[:, ::num_sub_beams] = 0 beam_scores = beam_scores.view((batch_size * num_beams,)) diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index df13a029a6..d2433868cf 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -316,6 +316,9 @@ class ExtensionsTrie(Trie): """ node = self.data for char in token: + if char not in node: + break + node = node[char] return node diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 5165e43c09..6766fa22b9 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -28,7 +28,6 @@ from transformers import AutoConfig, is_torch_available, pipeline, set_seed from transformers.testing_utils import ( is_flaky, require_accelerate, - require_auto_gptq, require_optimum_quanto, require_torch, require_torch_gpu, @@ -3912,11 +3911,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi class TokenHealingTestCase(unittest.TestCase): @parameterized.expand( [ - ( - "square_bracket", - 'An example ["like this"] and another example [', - 'An example ["like this"] and another example ["', - ), ("url", 'The link is