Fix flaky test_assisted_decoding_matches_greedy_search (#35951)
fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1185,7 +1185,9 @@ class GenerationTesterMixin:
|
||||
"return_dict_in_generate": True,
|
||||
"use_cache": True,
|
||||
}
|
||||
output_greedy = model.generate(**generation_kwargs, **inputs_dict)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
|
||||
|
||||
output_greedy = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
|
||||
|
||||
# test with the same assistant model or randomly init one
|
||||
# in the first case all candidate tokens are accepted, in the second none is accepted
|
||||
@@ -1197,7 +1199,7 @@ class GenerationTesterMixin:
|
||||
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
|
||||
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
|
||||
generation_kwargs.update({"assistant_model": assistant_model})
|
||||
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
|
||||
output_assisted = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
|
||||
|
||||
# The two outputs must match and their shape must be as expected
|
||||
self._check_similar_generate_outputs(output_greedy, output_assisted)
|
||||
|
||||
Reference in New Issue
Block a user