🚨🚨 Setting default behavior of assisted decoding (#33657)

This commit is contained in:
Jonathan Mamou
2024-09-25 11:39:09 +03:00
committed by GitHub
parent 5f0c181f4e
commit 52daf4ec76
4 changed files with 24 additions and 8 deletions

View File

@@ -159,6 +159,8 @@ class AssistedCandidateGenerator(CandidateGenerator):
self.generation_config.return_dict_in_generate = True self.generation_config.return_dict_in_generate = True
self.generation_config.output_scores = True self.generation_config.output_scores = True
self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold
# this flag allows us to set the confidence stopping criteria for assistant model generation.
self.generation_config.is_assistant = True
# avoid unnecessary warnings that min_length is larger than max_new_tokens # avoid unnecessary warnings that min_length is larger than max_new_tokens
# remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`) # remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`)

View File

@@ -338,19 +338,20 @@ class GenerationConfig(PushToHubMixin):
(e.g. multilingual models with different target languages in one batch) (e.g. multilingual models with different target languages in one batch)
> Generation parameters exclusive to assistant generation > Generation parameters exclusive to assistant generation
is_assistant (`bool`, *optional*, defaults to `False`):
num_assistant_tokens (`int`, *optional*, defaults to 5): Whether the model is an assistant (draft) model.
num_assistant_tokens (`int`, *optional*, defaults to 20):
Defines the number of _speculative tokens_ that shall be generated by the assistant model before being Defines the number of _speculative tokens_ that shall be generated by the assistant model before being
checked by the target model at each iteration. Higher values for `num_assistant_tokens` make the generation checked by the target model at each iteration. Higher values for `num_assistant_tokens` make the generation
more _speculative_ : If the assistant model is performant larger speed-ups can be reached, if the assistant more _speculative_ : If the assistant model is performant larger speed-ups can be reached, if the assistant
model requires lots of corrections, lower speed-ups are reached. model requires lots of corrections, lower speed-ups are reached.
num_assistant_tokens_schedule (`str`, *optional*, defaults to `"heuristic"`): num_assistant_tokens_schedule (`str`, *optional*, defaults to `"constant"`):
Defines the schedule at which max assistant tokens shall be changed during inference. Defines the schedule at which max assistant tokens shall be changed during inference.
- `"heuristic"`: When all speculative tokens are correct, increase `num_assistant_tokens` by 2 else - `"heuristic"`: When all speculative tokens are correct, increase `num_assistant_tokens` by 2 else
reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model. reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model.
- `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call. - `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call.
- `"constant"`: `num_assistant_tokens` stays unchanged during generation - `"constant"`: `num_assistant_tokens` stays unchanged during generation
assistant_confidence_threshold (`float`, *optional*): assistant_confidence_threshold (`float`, *optional*, defaults to 0.4):
The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower
than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_ than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
(defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead (defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead
@@ -452,9 +453,10 @@ class GenerationConfig(PushToHubMixin):
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
# Assistant generation # Assistant generation
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5) self.is_assistant = False
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic") self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 20)
self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", None) self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "constant")
self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0.4)
# Prompt lookup decoding # Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None) self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)

View File

@@ -953,7 +953,8 @@ class GenerationMixin:
if generation_config._eos_token_tensor is not None: if generation_config._eos_token_tensor is not None:
criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor)) criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
if ( if (
generation_config.assistant_confidence_threshold is not None generation_config.is_assistant
and generation_config.assistant_confidence_threshold is not None
and generation_config.assistant_confidence_threshold > 0 and generation_config.assistant_confidence_threshold > 0
): ):
criteria.append( criteria.append(

View File

@@ -2069,6 +2069,7 @@ class GenerationTesterMixin:
"assistant_model": assistant_model, "assistant_model": assistant_model,
} }
assistant_model.generation_config.assistant_confidence_threshold = None
# Setting num_logits_to_keep at 0 keeps all logits (old behavior) # Setting num_logits_to_keep at 0 keeps all logits (old behavior)
with_all_logits = model.generate( with_all_logits = model.generate(
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0 input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0
@@ -3098,6 +3099,16 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
) )
self.assertEqual(len(warning_list), 0) self.assertEqual(len(warning_list), 0)
def test_default_assisted_generation(self):
    """Check the assisted-generation defaults of a `GenerationConfig` built without arguments."""
    default_config = GenerationConfig()
    # (attribute name, expected default) pairs, verified in this order.
    expected_defaults = (
        ("num_assistant_tokens", 20),
        ("num_assistant_tokens_schedule", "constant"),
        ("assistant_confidence_threshold", 0.4),
        ("is_assistant", False),
    )
    for attribute, expected in expected_defaults:
        self.assertEqual(getattr(default_config, attribute), expected)
def test_generated_length_assisted_generation(self): def test_generated_length_assisted_generation(self):
# PT-only test: TF doesn't support assisted decoding yet. # PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)