fix num_assistant_tokens with heuristic schedule (#28759)
* fix heuristic num_assistant_tokens_schedule * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update utils.py check that candidate_generator.assistant_model exists since some some speculations (like ngram and PLD) don't have assistant_model attribute * Update src/transformers/generation/candidate_generator.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update tests/generation/test_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * make fixup * merge conflict * fix docstring * make fixup --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -225,7 +225,10 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
|||||||
# Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
|
# Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
|
||||||
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
|
# probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
|
||||||
# cost of forecasting incorrect assistant tokens.
|
# cost of forecasting incorrect assistant tokens.
|
||||||
if self.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic":
|
if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
|
||||||
|
"heuristic",
|
||||||
|
"heuristic_transient",
|
||||||
|
}:
|
||||||
if num_matches == int(self.num_assistant_tokens):
|
if num_matches == int(self.num_assistant_tokens):
|
||||||
self.num_assistant_tokens += 2.0
|
self.num_assistant_tokens += 2.0
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -249,8 +249,9 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
|
|
||||||
num_assistant_tokens_schedule (`str`, *optional*, defaults to `"heuristic"`):
|
num_assistant_tokens_schedule (`str`, *optional*, defaults to `"heuristic"`):
|
||||||
Defines the schedule at which max assistant tokens shall be changed during inference.
|
Defines the schedule at which max assistant tokens shall be changed during inference.
|
||||||
- `"_heuristic_`: When all _speculative_ tokens are correct, increase `num_assistant_tokens` by 2 else
|
- `"heuristic"`: When all speculative tokens are correct, increase `num_assistant_tokens` by 2 else
|
||||||
reduce by 1
|
reduce by 1. `num_assistant_tokens` value is persistent over multiple generation calls with the same assistant model.
|
||||||
|
- `"heuristic_transient"`: Same as `"heuristic"` but `num_assistant_tokens` is reset to its initial value after each generation call.
|
||||||
- `"constant"`: `num_assistant_tokens` stays unchanged during generation
|
- `"constant"`: `num_assistant_tokens` stays unchanged during generation
|
||||||
|
|
||||||
> Parameters specific to the caching mechanism:
|
> Parameters specific to the caching mechanism:
|
||||||
|
|||||||
@@ -4561,6 +4561,13 @@ class GenerationMixin:
|
|||||||
if streamer is not None:
|
if streamer is not None:
|
||||||
streamer.end()
|
streamer.end()
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(candidate_generator, "assistant_model")
|
||||||
|
and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic"
|
||||||
|
):
|
||||||
|
candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
|
||||||
|
candidate_generator.num_assistant_tokens
|
||||||
|
)
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
return GenerateEncoderDecoderOutput(
|
return GenerateEncoderDecoderOutput(
|
||||||
|
|||||||
@@ -3490,3 +3490,49 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
)
|
)
|
||||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||||
|
|
||||||
|
def test_assisted_decoding_num_assistant_tokens_heuristic_schedule(self):
|
||||||
|
# This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.
|
||||||
|
|
||||||
|
prompt = "Alice and Bob"
|
||||||
|
checkpoint = "EleutherAI/pythia-160m-deduped"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(checkpoint)
|
||||||
|
|
||||||
|
assistant_model = model
|
||||||
|
assistant_model.generation_config.num_assistant_tokens = 5
|
||||||
|
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic"
|
||||||
|
generation_kwargs = {
|
||||||
|
"eos_token_id": -1,
|
||||||
|
"max_new_tokens": 5,
|
||||||
|
"do_sample": False,
|
||||||
|
"assistant_model": assistant_model,
|
||||||
|
}
|
||||||
|
model.generate(**inputs, **generation_kwargs)
|
||||||
|
# update_candidate_strategy is called only once and therefore, assistant_model.generation_config.num_assistant_tokens should be either 4 or 7
|
||||||
|
self.assertTrue(assistant_model.generation_config.num_assistant_tokens in (4, 7))
|
||||||
|
|
||||||
|
def test_assisted_decoding_num_assistant_tokens_heuristic_transient_schedule(self):
|
||||||
|
# This test ensures that the assisted generation num_assistant_tokens 'heuristic' schedule works properly.
|
||||||
|
|
||||||
|
prompt = "Alice and Bob"
|
||||||
|
checkpoint = "EleutherAI/pythia-160m-deduped"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(checkpoint)
|
||||||
|
|
||||||
|
assistant_model = model
|
||||||
|
assistant_model.generation_config.num_assistant_tokens = 5
|
||||||
|
assistant_model.generation_config.num_assistant_tokens_schedule = "heuristic_transient"
|
||||||
|
generation_kwargs = {
|
||||||
|
"eos_token_id": -1,
|
||||||
|
"max_new_tokens": 5,
|
||||||
|
"do_sample": False,
|
||||||
|
"assistant_model": assistant_model,
|
||||||
|
}
|
||||||
|
model.generate(**inputs, **generation_kwargs)
|
||||||
|
# update_candidate_strategy is called once but assistant_model.generation_config.num_assistant_tokens should stay 5
|
||||||
|
self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 5)
|
||||||
|
|||||||
Reference in New Issue
Block a user