From cf90404807658b140909b2a2be4865906e7abd09 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 29 Jan 2025 14:50:07 +0100 Subject: [PATCH] Fix flaky `test_assisted_decoding_matches_greedy_search` (#35951) Pass the kwargs from `self._get_logits_processor_kwargs(config=model.config)` to both the greedy and the assisted `model.generate` calls. Co-authored-by: ydshieh --- tests/generation/test_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 26f3b3e324..3aaa16704e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1185,7 +1185,9 @@ class GenerationTesterMixin: "return_dict_in_generate": True, "use_cache": True, } - output_greedy = model.generate(**generation_kwargs, **inputs_dict) + logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config) + + output_greedy = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs) # test with the same assistant model or randomly init one # in the first case all candidate tokens are accepted, in the second none is accepted @@ -1197,7 +1199,7 @@ class GenerationTesterMixin: assistant_model.generation_config.num_assistant_tokens = 2 # see b) assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b) generation_kwargs.update({"assistant_model": assistant_model}) - output_assisted = model.generate(**generation_kwargs, **inputs_dict) + output_assisted = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs) # The two outputs must match and their shape must be as expected self._check_similar_generate_outputs(output_greedy, output_assisted)