Refine the code of Universal Assisted Generation (#34823)
* removed the useless attritbutes * add configs for window size * fixed the wrong kwargs * added docstring
This commit is contained in:
@@ -310,10 +310,9 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
|||||||
|
|
||||||
self.target_tokenizer = target_tokenizer
|
self.target_tokenizer = target_tokenizer
|
||||||
self.assistant_tokenizer = assistant_tokenizer
|
self.assistant_tokenizer = assistant_tokenizer
|
||||||
self.prev_tokens = None
|
|
||||||
self.prev_assistant_ids = None
|
self.prev_assistant_ids = None
|
||||||
self.target_lookbehind = 10
|
self.target_lookbehind = assistant_model.generation_config.target_lookbehind
|
||||||
self.assistant_lookbehind = 10
|
self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_longest_diag_dict(input_matrix, nonzero_idx):
|
def _get_longest_diag_dict(input_matrix, nonzero_idx):
|
||||||
@@ -450,9 +449,9 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
|||||||
# Since re-encoding the tokens may result in tokenization discrepancies, we use 2 look behind values
|
# Since re-encoding the tokens may result in tokenization discrepancies, we use 2 look behind values
|
||||||
# (one for each conversion) which mark where to start looking for the overlap between the
|
# (one for each conversion) which mark where to start looking for the overlap between the
|
||||||
# source and target encodings, to ensure the new tokens include the correct prompt suffix.
|
# source and target encodings, to ensure the new tokens include the correct prompt suffix.
|
||||||
if self.prev_tokens is not None and self.prev_target_ids.shape[1] > self.target_lookbehind:
|
if self.prev_assistant_ids is not None and input_ids.shape[1] > self.target_lookbehind:
|
||||||
# input_ids contains all target prompt input ids and some new target input ids
|
# input_ids contains all target prompt input ids and some new target input ids
|
||||||
start_index_in_target_window = self.prev_target_ids.shape[1] - self.target_lookbehind
|
start_index_in_target_window = input_ids.shape[1] - self.target_lookbehind
|
||||||
|
|
||||||
new_assistant_ids = self.convert_source_tokens_to_target_tokens(
|
new_assistant_ids = self.convert_source_tokens_to_target_tokens(
|
||||||
input_ids[:, start_index_in_target_window:], **convert_kwargs
|
input_ids[:, start_index_in_target_window:], **convert_kwargs
|
||||||
@@ -485,7 +484,6 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
|
assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
|
||||||
self.prev_target_ids = input_ids
|
|
||||||
|
|
||||||
self.prev_assistant_ids = assistant_input_ids
|
self.prev_assistant_ids = assistant_input_ids
|
||||||
new_cur_len = assistant_input_ids.shape[-1]
|
new_cur_len = assistant_input_ids.shape[-1]
|
||||||
@@ -520,6 +518,8 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
|||||||
|
|
||||||
num_prev_assistant = self.prev_assistant_ids.shape[1]
|
num_prev_assistant = self.prev_assistant_ids.shape[1]
|
||||||
start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind
|
start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind
|
||||||
|
if start_assistant_look_index < 0:
|
||||||
|
start_assistant_look_index = 0
|
||||||
|
|
||||||
new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
|
new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
|
||||||
assistant_output.sequences[:, start_assistant_look_index:],
|
assistant_output.sequences[:, start_assistant_look_index:],
|
||||||
@@ -543,14 +543,11 @@ class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
|
|||||||
# edge case: in case of no intersection between prompt and new_target_ids
|
# edge case: in case of no intersection between prompt and new_target_ids
|
||||||
new_target_ids = torch.cat([new_target_ids, new_target_ids_from_window], dim=-1)
|
new_target_ids = torch.cat([new_target_ids, new_target_ids_from_window], dim=-1)
|
||||||
|
|
||||||
self.prev_target_ids = input_ids
|
|
||||||
|
|
||||||
if hasattr(self.generation_config, "max_length"):
|
if hasattr(self.generation_config, "max_length"):
|
||||||
new_target_ids = new_target_ids[:, : self.generation_config.max_length]
|
new_target_ids = new_target_ids[:, : self.generation_config.max_length]
|
||||||
|
|
||||||
# 3. Update variables for the next round of candidate generation
|
# 3. Update variables for the next round of candidate generation
|
||||||
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
|
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
|
||||||
self.prev_tokens = assistant_output.sequences
|
|
||||||
|
|
||||||
# 4. Prepare variables for output
|
# 4. Prepare variables for output
|
||||||
if input_ids.shape[1] >= new_target_ids.shape[1]:
|
if input_ids.shape[1] >= new_target_ids.shape[1]:
|
||||||
|
|||||||
@@ -360,6 +360,14 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
assistant_early_exit(`int`, *optional*):
|
assistant_early_exit(`int`, *optional*):
|
||||||
If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with
|
If set to a positive integer, early exit of the model will be used as an assistant. Can only be used with
|
||||||
models that support early exit (i.e. models where logits from intermediate layers can be interpreted by the LM head).
|
models that support early exit (i.e. models where logits from intermediate layers can be interpreted by the LM head).
|
||||||
|
assistant_lookbehind(`int`, *optional*, defaults to 10):
|
||||||
|
If set to a positive integer, the re-encodeing process will additionally consider the last `assistant_lookbehind` assistant tokens
|
||||||
|
to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
|
||||||
|
See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
|
||||||
|
target_lookbehind(`int`, *optional*, defaults to 10):
|
||||||
|
If set to a positive integer, the re-encodeing process will additionally consider the last `target_lookbehind` target tokens
|
||||||
|
to correctly align tokens. Can only be used with different tokenizers in speculative decoding.
|
||||||
|
See this [blog](https://huggingface.co/blog/universal_assisted_generation) for more details.
|
||||||
|
|
||||||
> Wild card
|
> Wild card
|
||||||
|
|
||||||
@@ -460,6 +468,9 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
|
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
|
||||||
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
|
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)
|
||||||
self.assistant_early_exit = kwargs.pop("assistant_early_exit", None)
|
self.assistant_early_exit = kwargs.pop("assistant_early_exit", None)
|
||||||
|
## assistant generation for different tokenizers, the windows size for assistant/target model
|
||||||
|
self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10)
|
||||||
|
self.target_lookbehind = kwargs.pop("target_lookbehind", 10)
|
||||||
|
|
||||||
# Wild card
|
# Wild card
|
||||||
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
|
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
|
||||||
|
|||||||
@@ -3812,7 +3812,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
do_sample=False,
|
do_sample=False,
|
||||||
max_new_tokens=max_new_tokens_item,
|
max_new_tokens=max_new_tokens_item,
|
||||||
assistant_model=draft_model,
|
assistant_model=draft_model,
|
||||||
target_tokenizer=target_tokenizer,
|
tokenizer=target_tokenizer,
|
||||||
assistant_tokenizer=assistant_tokenizer,
|
assistant_tokenizer=assistant_tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user