From 52daf4ec768fb9ffe84a0c373834172a7c54aecc Mon Sep 17 00:00:00 2001 From: Jonathan Mamou Date: Wed, 25 Sep 2024 11:39:09 +0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=F0=9F=9A=A8=20Setting=20default=20?= =?UTF-8?q?behavior=20of=20assisted=20decoding=20(#33657)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../generation/candidate_generator.py | 2 ++ .../generation/configuration_utils.py | 16 +++++++++------- src/transformers/generation/utils.py | 3 ++- tests/generation/test_utils.py | 11 +++++++++++ 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 0b799dceb2..fb3120b3ce 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -159,6 +159,8 @@ class AssistedCandidateGenerator(CandidateGenerator): self.generation_config.return_dict_in_generate = True self.generation_config.output_scores = True self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold + # this flag allow us set the confidence stopping criteria for assistant model generation. + self.generation_config.is_assistant = True # avoid unnecessary warnings that min_length is larger than max_new_tokens # remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 5e9ac835c1..60e9323dcb 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -338,19 +338,20 @@ class GenerationConfig(PushToHubMixin): (e.g. multilingual models with different target languages in one batch) > Generation parameters exclusive to assistant generation - - num_assistant_tokens (`int`, *optional*, defaults to 5): + is_assistant (`bool`, *optional*, defaults to `False`): + Whether the model is an assistant (draft) model. + num_assistant_tokens (`int`, *optional*, defaults to 20): Defines the number of _speculative tokens_ that shall be generated by the assistant model before being checked by the target model at each iteration. Higher values for `num_assistant_tokens` make the generation more _speculative_ : If the assistant model is performant larger speed-ups can be reached, if the assistant model requires lots of corrections, lower speed-ups are reached. - num_assistant_tokens_schedule (`str`, *optional*, defaults to `"heuristic"`): + num_assistant_tokens_schedule (`str`, *optional*, defaults to `"constant"`): Defines the schedule at which max assistant tokens shall be changed during inference. - `"heuristic"`: When all speculative tokens are correct, increase `num_assistant_tokens` by 2 else reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model. - `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call. - `"constant"`: `num_assistant_tokens` stays unchanged during generation - assistant_confidence_threshold (`float`, *optional*): + assistant_confidence_threshold (`float`, *optional*, defaults to 0.4): The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_ (defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead @@ -452,9 +453,10 @@ class GenerationConfig(PushToHubMixin): self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) # Assistant generation - self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5) - self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic") - self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", None) + self.is_assistant = False + self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 20) + self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "constant") + self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0.4) # Prompt lookup decoding self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c1aa338a7d..aedd1674df 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -953,7 +953,8 @@ class GenerationMixin: if generation_config._eos_token_tensor is not None: criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor)) if ( - generation_config.assistant_confidence_threshold is not None + generation_config.is_assistant + and generation_config.assistant_confidence_threshold is not None and generation_config.assistant_confidence_threshold > 0 ): criteria.append( diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 9754a4b7dc..3f8e99a334 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2069,6 +2069,7 @@ class GenerationTesterMixin: "assistant_model": assistant_model, } + assistant_model.generation_config.assistant_confidence_threshold = None # Setting num_logits_to_keep at 0 keeps all logits (old behavior) with_all_logits = model.generate( input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0 @@ -3098,6 +3099,16 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi ) self.assertEqual(len(warning_list), 0) + def test_default_assisted_generation(self): + # Initialize the GenerationConfig object + config = GenerationConfig() + + # Check the default values + self.assertEqual(config.num_assistant_tokens, 20) + self.assertEqual(config.num_assistant_tokens_schedule, "constant") + self.assertEqual(config.assistant_confidence_threshold, 0.4) + self.assertEqual(config.is_assistant, False) + def test_generated_length_assisted_generation(self): # PT-only test: TF doesn't support assisted decoding yet. model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)