Refine the code of Universal Assisted Generation (#34823)
* removed the useless attributes * added configs for window size * fixed the wrong kwargs * added docstring
This commit is contained in:
@@ -310,10 +310,9 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
||||
|
||||
self.target_tokenizer = target_tokenizer
|
||||
self.assistant_tokenizer = assistant_tokenizer
|
||||
self.prev_tokens = None
|
||||
self.prev_assistant_ids = None
|
||||
self.target_lookbehind = 10
|
||||
self.assistant_lookbehind = 10
|
||||
self.target_lookbehind = assistant_model.generation_config.target_lookbehind
|
||||
self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind
|
||||
|
||||
@staticmethod
|
||||
def _get_longest_diag_dict(input_matrix, nonzero_idx):
|
||||
@@ -450,9 +449,9 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
||||
# Since re-encoding the tokens may result in tokenization discrepancies, we use 2 look behind values
|
||||
# (one for each conversion) which mark where to start looking for the overlap between the
|
||||
# source and target encodings, to ensure the new tokens include the correct prompt suffix.
|
||||
if self.prev_tokens is not None and self.prev_target_ids.shape[1] > self.target_lookbehind:
|
||||
if self.prev_assistant_ids is not None and input_ids.shape[1] > self.target_lookbehind:
|
||||
# input_ids contains all target prompt input ids and some new target input ids
|
||||
start_index_in_target_window = self.prev_target_ids.shape[1] - self.target_lookbehind
|
||||
start_index_in_target_window = input_ids.shape[1] - self.target_lookbehind
|
||||
|
||||
new_assistant_ids = self.convert_source_tokens_to_target_tokens(
|
||||
input_ids[:, start_index_in_target_window:], **convert_kwargs
|
||||
@@ -485,7 +484,6 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
||||
|
||||
else:
|
||||
assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
|
||||
self.prev_target_ids = input_ids
|
||||
|
||||
self.prev_assistant_ids = assistant_input_ids
|
||||
new_cur_len = assistant_input_ids.shape[-1]
|
||||
@@ -520,6 +518,8 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
||||
|
||||
num_prev_assistant = self.prev_assistant_ids.shape[1]
|
||||
start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind
|
||||
if start_assistant_look_index < 0:
|
||||
start_assistant_look_index = 0
|
||||
|
||||
new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
|
||||
assistant_output.sequences[:, start_assistant_look_index:],
|
||||
@@ -543,14 +543,11 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
||||
# edge case: in case of no intersection between prompt and new_target_ids
|
||||
new_target_ids = torch.cat([new_target_ids, new_target_ids_from_window], dim=-1)
|
||||
|
||||
self.prev_target_ids = input_ids
|
||||
|
||||
if hasattr(self.generation_config, "max_length"):
|
||||
new_target_ids = new_target_ids[:, : self.generation_config.max_length]
|
||||
|
||||
# 3. Update variables for the next round of candidate generation
|
||||
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
|
||||
self.prev_tokens = assistant_output.sequences
|
||||
|
||||
# 4. Prepare variables for output
|
||||
if input_ids.shape[1] >= new_target_ids.shape[1]:
|
||||
|
||||
@@ -360,6 +360,14 @@ class GenerationConfig(PushToHubMixin):
|
||||
assistant_early_exit(`int`, *optional*):
|
||||
If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with
|
||||
models that support early exit (i.e. models where logits from intermediate layers can be interpreted by the LM head).
|
||||
assistant_lookbehind(`int`, *optional*, defaults to 10):
|
||||
If set to a positive integer, the re-encoding process will additionally consider the last `assistant_lookbehind` assistant tokens
|
||||
to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
|
||||
See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
|
||||
target_lookbehind(`int`, *optional*, defaults to 10):
|
||||
If set to a positive integer, the re-encoding process will additionally consider the last `target_lookbehind` target tokens
|
||||
to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
|
||||
See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
|
||||
|
||||
> Wild card
|
||||
|
||||
@@ -460,6 +468,9 @@ class GenerationConfig(PushToHubMixin):
|
||||
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
|
||||
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
|
||||
self.assistant_early_exit = kwargs.pop("assistant_early_exit", None)
|
||||
## assistant generation for different tokenizers, the window size for assistant/target model
|
||||
self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
|
||||
self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
|
||||
|
||||
# Wild card
|
||||
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
|
||||
|
||||
@@ -3812,7 +3812,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
do_sample=False,
|
||||
max_new_tokens=max_new_tokens_item,
|
||||
assistant_model=draft_model,
|
||||
target_tokenizer=target_tokenizer,
|
||||
tokenizer=target_tokenizer,
|
||||
assistant_tokenizer=assistant_tokenizer,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user