Fix flaky test_assisted_decoding_matches_greedy_search (#35951)

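Pass the kwargs returned by `_get_logits_processor_kwargs(config=model.config)` to both the greedy and the assisted `model.generate` calls in the test, so the two outputs being compared are generated with the same logits-processor settings.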

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-01-29 14:50:07 +01:00
committed by GitHub
parent 692afa102d
commit cf90404807

View File

@@ -1185,7 +1185,9 @@ class GenerationTesterMixin:
"return_dict_in_generate": True,
"use_cache": True,
}
output_greedy = model.generate(**generation_kwargs, **inputs_dict)
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
output_greedy = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
# test with the same assistant model or randomly init one
# in the first case all candidate tokens are accepted, in the second none is accepted
@@ -1197,7 +1199,7 @@ class GenerationTesterMixin:
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
generation_kwargs.update({"assistant_model": assistant_model})
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
output_assisted = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
# The two outputs must match and their shape must be as expected
self._check_similar_generate_outputs(output_greedy, output_assisted)