From 44af935ec58f417febd72e43baeae024d0ade18c Mon Sep 17 00:00:00 2001 From: xinpengzz Date: Thu, 28 Nov 2024 22:04:24 +0800 Subject: [PATCH] Refine the code of Universal Assisted Generation (#34823) * removed the useless attributes * add configs for window size * fixed the wrong kwargs * added docstring --- .../generation/candidate_generator.py | 15 ++++++--------- .../generation/configuration_utils.py | 11 +++++++++++ tests/generation/test_utils.py | 2 +- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index df213b458c..7cab88a4bc 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -310,10 +310,9 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator): self.target_tokenizer = target_tokenizer self.assistant_tokenizer = assistant_tokenizer - self.prev_tokens = None self.prev_assistant_ids = None - self.target_lookbehind = 10 - self.assistant_lookbehind = 10 + self.target_lookbehind = assistant_model.generation_config.target_lookbehind + self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind @staticmethod def _get_longest_diag_dict(input_matrix, nonzero_idx): @@ -450,9 +449,9 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator): # Since re-encoding the tokens may result in tokenization discrepancies, we use 2 look behind values # (one for each conversion) which mark where to start looking for the overlap between the # source and target encodings, to ensure the new tokens include the correct prompt suffix.
- if self.prev_tokens is not None and self.prev_target_ids.shape[1] > self.target_lookbehind: + if self.prev_assistant_ids is not None and input_ids.shape[1] > self.target_lookbehind: # input_ids contains all target prompt input ids and some new target input ids - start_index_in_target_window = self.prev_target_ids.shape[1] - self.target_lookbehind + start_index_in_target_window = input_ids.shape[1] - self.target_lookbehind new_assistant_ids = self.convert_source_tokens_to_target_tokens( input_ids[:, start_index_in_target_window:], **convert_kwargs @@ -485,7 +484,6 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator): else: assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs) - self.prev_target_ids = input_ids self.prev_assistant_ids = assistant_input_ids new_cur_len = assistant_input_ids.shape[-1] @@ -520,6 +518,8 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator): num_prev_assistant = self.prev_assistant_ids.shape[1] start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind + if start_assistant_look_index < 0: + start_assistant_look_index = 0 new_target_ids_from_window = self.convert_source_tokens_to_target_tokens( assistant_output.sequences[:, start_assistant_look_index:], @@ -543,14 +543,11 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator): # edge case: in case of no intersection between prompt and new_target_ids new_target_ids = torch.cat([new_target_ids, new_target_ids_from_window], dim=-1) - self.prev_target_ids = input_ids - if hasattr(self.generation_config, "max_length"): new_target_ids = new_target_ids[:, : self.generation_config.max_length] # 3. Update variables for the next round of candidate generation self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values - self.prev_tokens = assistant_output.sequences # 4. 
Prepare variables for output if input_ids.shape[1] >= new_target_ids.shape[1]: diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index de62ee767a..cbc445308a 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -360,6 +360,14 @@ class GenerationConfig(PushToHubMixin): assistant_early_exit(`int`, *optional*): If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with models that support early exit (i.e. models where logits from intermediate layers can be interpreted by the LM head). + assistant_lookbehind(`int`, *optional*, defaults to 10): + If set to a positive integer, the re-encoding process will additionally consider the last `assistant_lookbehind` assistant tokens + to correctly align tokens. Can only be used with different tokenizers in speculative decoding. + See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details. + target_lookbehind(`int`, *optional*, defaults to 10): + If set to a positive integer, the re-encoding process will additionally consider the last `target_lookbehind` target tokens + to correctly align tokens. Can only be used with different tokenizers in speculative decoding. + See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
> Wild card @@ -460,6 +468,9 @@ class GenerationConfig(PushToHubMixin): self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None) self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None) self.assistant_early_exit = kwargs.pop("assistant_early_exit", None) + ## assistant generation for different tokenizers, the windows size for assistant/target model + self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10) + self.target_lookbehind = kwargs.pop("target_lookbehind", 10) # Wild card self.generation_kwargs = kwargs.pop("generation_kwargs", {}) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 0605ea7939..e403a528a8 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -3812,7 +3812,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi do_sample=False, max_new_tokens=max_new_tokens_item, assistant_model=draft_model, - target_tokenizer=target_tokenizer, + tokenizer=target_tokenizer, assistant_tokenizer=assistant_tokenizer, )