Fix flaky test_assisted_decoding_matches_greedy_search (#35951)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1185,7 +1185,9 @@ class GenerationTesterMixin:
|
|||||||
"return_dict_in_generate": True,
|
"return_dict_in_generate": True,
|
||||||
"use_cache": True,
|
"use_cache": True,
|
||||||
}
|
}
|
||||||
output_greedy = model.generate(**generation_kwargs, **inputs_dict)
|
logits_processor_kwargs = self._get_logits_processor_kwargs(config=model.config)
|
||||||
|
|
||||||
|
output_greedy = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
|
||||||
|
|
||||||
# test with the same assistant model or randomly init one
|
# test with the same assistant model or randomly init one
|
||||||
# in the first case all candidate tokens are accepted, in the second none is accepted
|
# in the first case all candidate tokens are accepted, in the second none is accepted
|
||||||
@@ -1197,7 +1199,7 @@ class GenerationTesterMixin:
|
|||||||
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
|
assistant_model.generation_config.num_assistant_tokens = 2 # see b)
|
||||||
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
|
assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b)
|
||||||
generation_kwargs.update({"assistant_model": assistant_model})
|
generation_kwargs.update({"assistant_model": assistant_model})
|
||||||
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
|
output_assisted = model.generate(**generation_kwargs, **inputs_dict, **logits_processor_kwargs)
|
||||||
|
|
||||||
# The two outputs must match and their shape must be as expected
|
# The two outputs must match and their shape must be as expected
|
||||||
self._check_similar_generate_outputs(output_greedy, output_assisted)
|
self._check_similar_generate_outputs(output_greedy, output_assisted)
|
||||||
|
|||||||
Reference in New Issue
Block a user