🚨🚨 Setting default behavior of assisted decoding (#33657)
This commit is contained in:
@@ -159,6 +159,8 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
|||||||
self.generation_config.return_dict_in_generate = True
|
self.generation_config.return_dict_in_generate = True
|
||||||
self.generation_config.output_scores = True
|
self.generation_config.output_scores = True
|
||||||
self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold
|
self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold
|
||||||
|
# this flag allows us to set the confidence stopping criteria for assistant model generation.
|
||||||
|
self.generation_config.is_assistant = True
|
||||||
|
|
||||||
# avoid unnecessary warnings that min_length is larger than max_new_tokens
|
# avoid unnecessary warnings that min_length is larger than max_new_tokens
|
||||||
# remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`)
|
# remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`)
|
||||||
|
|||||||
@@ -338,19 +338,20 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
(e.g. multilingual models with different target languages in one batch)
|
(e.g. multilingual models with different target languages in one batch)
|
||||||
|
|
||||||
> Generation parameters exclusive to assistant generation
|
> Generation parameters exclusive to assistant generation
|
||||||
|
is_assistant (`bool`, *optional*, defaults to `False`):
|
||||||
num_assistant_tokens (`int`, *optional*, defaults to 5):
|
Whether the model is an assistant (draft) model.
|
||||||
|
num_assistant_tokens (`int`, *optional*, defaults to 20):
|
||||||
Defines the number of _speculative tokens_ that shall be generated by the assistant model before being
|
Defines the number of _speculative tokens_ that shall be generated by the assistant model before being
|
||||||
checked by the target model at each iteration. Higher values for `num_assistant_tokens` make the generation
|
checked by the target model at each iteration. Higher values for `num_assistant_tokens` make the generation
|
||||||
more _speculative_: if the assistant model is performant, larger speed-ups can be reached; if the assistant
|
more _speculative_: if the assistant model is performant, larger speed-ups can be reached; if the assistant
|
||||||
model requires lots of corrections, lower speed-ups are reached.
|
model requires lots of corrections, lower speed-ups are reached.
|
||||||
num_assistant_tokens_schedule (`str`, *optional*, defaults to `"heuristic"`):
|
num_assistant_tokens_schedule (`str`, *optional*, defaults to `"constant"`):
|
||||||
Defines the schedule at which max assistant tokens shall be changed during inference.
|
Defines the schedule at which max assistant tokens shall be changed during inference.
|
||||||
- `"heuristic"`: When all speculative tokens are correct, increase `num_assistant_tokens` by 2 else
|
- `"heuristic"`: When all speculative tokens are correct, increase `num_assistant_tokens` by 2 else
|
||||||
reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model.
|
reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model.
|
||||||
- `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call.
|
- `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call.
|
||||||
- `"constant"`: `num_assistant_tokens` stays unchanged during generation
|
- `"constant"`: `num_assistant_tokens` stays unchanged during generation
|
||||||
assistant_confidence_threshold (`float`, *optional*):
|
assistant_confidence_threshold (`float`, *optional*, defaults to 0.4):
|
||||||
The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower
|
The confidence threshold for the assistant model. If the assistant model's confidence in its prediction for the current token is lower
|
||||||
than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
|
than this threshold, the assistant model stops the current token generation iteration, even if the number of _speculative tokens_
|
||||||
(defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead
|
(defined by `num_assistant_tokens`) is not yet reached. It is an unsupervised version of the dynamic speculation lookahead
|
||||||
@@ -452,9 +453,10 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
|
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
|
||||||
|
|
||||||
# Assistant generation
|
# Assistant generation
|
||||||
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
|
self.is_assistant = False
|
||||||
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
|
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 20)
|
||||||
self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", None)
|
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "constant")
|
||||||
|
self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0.4)
|
||||||
|
|
||||||
# Prompt lookup decoding
|
# Prompt lookup decoding
|
||||||
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
|
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
|
||||||
|
|||||||
@@ -953,7 +953,8 @@ class GenerationMixin:
|
|||||||
if generation_config._eos_token_tensor is not None:
|
if generation_config._eos_token_tensor is not None:
|
||||||
criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
|
criteria.append(EosTokenCriteria(eos_token_id=generation_config._eos_token_tensor))
|
||||||
if (
|
if (
|
||||||
generation_config.assistant_confidence_threshold is not None
|
generation_config.is_assistant
|
||||||
|
and generation_config.assistant_confidence_threshold is not None
|
||||||
and generation_config.assistant_confidence_threshold > 0
|
and generation_config.assistant_confidence_threshold > 0
|
||||||
):
|
):
|
||||||
criteria.append(
|
criteria.append(
|
||||||
|
|||||||
@@ -2069,6 +2069,7 @@ class GenerationTesterMixin:
|
|||||||
"assistant_model": assistant_model,
|
"assistant_model": assistant_model,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assistant_model.generation_config.assistant_confidence_threshold = None
|
||||||
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
|
||||||
with_all_logits = model.generate(
|
with_all_logits = model.generate(
|
||||||
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0
|
input_ids, attention_mask=attention_mask, **generation_kwargs, **inputs_dict, num_logits_to_keep=0
|
||||||
@@ -3098,6 +3099,16 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
)
|
)
|
||||||
self.assertEqual(len(warning_list), 0)
|
self.assertEqual(len(warning_list), 0)
|
||||||
|
|
||||||
|
def test_default_assisted_generation(self):
|
||||||
|
# Initialize the GenerationConfig object
|
||||||
|
config = GenerationConfig()
|
||||||
|
|
||||||
|
# Check the default values
|
||||||
|
self.assertEqual(config.num_assistant_tokens, 20)
|
||||||
|
self.assertEqual(config.num_assistant_tokens_schedule, "constant")
|
||||||
|
self.assertEqual(config.assistant_confidence_threshold, 0.4)
|
||||||
|
self.assertEqual(config.is_assistant, False)
|
||||||
|
|
||||||
def test_generated_length_assisted_generation(self):
|
def test_generated_length_assisted_generation(self):
|
||||||
# PT-only test: TF doesn't support assisted decoding yet.
|
# PT-only test: TF doesn't support assisted decoding yet.
|
||||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user