Refine the code of Universal Assisted Generation (#34823)

* removed the useless attritbutes

* add configs for window size

* fixed the wrong kwargs

* added docstring
This commit is contained in:
xinpengzz
2024-11-28 22:04:24 +08:00
committed by GitHub
parent 2b053fdf1a
commit 44af935ec5
3 changed files with 18 additions and 10 deletions

View File

@@ -310,10 +310,9 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
self.target_tokenizer = target_tokenizer
self.assistant_tokenizer = assistant_tokenizer
self.prev_tokens = None
self.prev_assistant_ids = None
self.target_lookbehind = 10
self.assistant_lookbehind = 10
self.target_lookbehind = assistant_model.generation_config.target_lookbehind
self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind
@staticmethod
def _get_longest_diag_dict(input_matrix, nonzero_idx):
@@ -450,9 +449,9 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
# Since re-encoding the tokens may result in tokenization discrepancies, we use 2 look behind values
# (one for each conversion) which mark where to start looking for the overlap between the
# source and target encodings, to ensure the new tokens include the correct prompt suffix.
if self.prev_tokens is not None and self.prev_target_ids.shape[1] > self.target_lookbehind:
if self.prev_assistant_ids is not None and input_ids.shape[1] > self.target_lookbehind:
# input_ids contains all target prompt input ids and some new target input ids
start_index_in_target_window = self.prev_target_ids.shape[1] - self.target_lookbehind
start_index_in_target_window = input_ids.shape[1] - self.target_lookbehind
new_assistant_ids = self.convert_source_tokens_to_target_tokens(
input_ids[:, start_index_in_target_window:], **convert_kwargs
@@ -485,7 +484,6 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
else:
assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
self.prev_target_ids = input_ids
self.prev_assistant_ids = assistant_input_ids
new_cur_len = assistant_input_ids.shape[-1]
@@ -520,6 +518,8 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
num_prev_assistant = self.prev_assistant_ids.shape[1]
start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind
if start_assistant_look_index < 0:
start_assistant_look_index = 0
new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
assistant_output.sequences[:, start_assistant_look_index:],
@@ -543,14 +543,11 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
# edge case: in case of no intersection between prompt and new_target_ids
new_target_ids = torch.cat([new_target_ids, new_target_ids_from_window], dim=-1)
self.prev_target_ids = input_ids
if hasattr(self.generation_config, "max_length"):
new_target_ids = new_target_ids[:, : self.generation_config.max_length]
# 3. Update variables for the next round of candidate generation
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
self.prev_tokens = assistant_output.sequences
# 4. Prepare variables for output
if input_ids.shape[1] >= new_target_ids.shape[1]:

View File

@@ -360,6 +360,14 @@ class GenerationConfig(PushToHubMixin):
assistant_early_exit(`int`, *optional*):
If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with
models that support early exit (i.e. models where logits from intermediate layers can be interpreted by the LM head).
assistant_lookbehind(`int`, *optional*, defaults to 10):
If set to a positive integer, the re-encodeing process will additionally consider the last `assistant_lookbehind` assistant tokens
to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
target_lookbehind(`int`, *optional*, defaults to 10):
If set to a positive integer, the re-encodeing process will additionally consider the last `target_lookbehind` target tokens
to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
> Wild card
@@ -460,6 +468,9 @@ class GenerationConfig(PushToHubMixin):
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
self.assistant_early_exit = kwargs.pop("assistant_early_exit", None)
## assistant generation for different tokenizers, the windows size for assistant/target model
self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})

View File

@@ -3812,7 +3812,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
do_sample=False,
max_new_tokens=max_new_tokens_item,
assistant_model=draft_model,
target_tokenizer=target_tokenizer,
tokenizer=target_tokenizer,
assistant_tokenizer=assistant_tokenizer,
)