make test_eager_matches_sdpa_inference less flaky (#34512)
* try * try * try * try * try * try * update * update * update * update * update * update * update --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1263,6 +1263,9 @@ class GenerationTesterMixin:
|
||||
|
||||
if model.get_output_embeddings() is None:
|
||||
self.skipTest("DoLa is not supported for models that don't have output embeddings")
|
||||
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
|
||||
|
||||
# Sets dola generation arguments such that:
|
||||
# a) no EOS is generated, to ensure generation doesn't break early
|
||||
# b) there are at least two forward passes in the main model, to ensure the input preparation of
|
||||
@@ -1280,7 +1283,7 @@ class GenerationTesterMixin:
|
||||
"use_cache": getattr(config, "use_cache", False), # Some models don't support the cache
|
||||
"dola_layers": "low",
|
||||
}
|
||||
- output_dola = model.generate(**generation_kwargs, **inputs_dict)
|
||||
+ output_dola = model.generate(**generation_kwargs, **logits_processor_kwargs, **inputs_dict)
|
||||
self._check_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False))
|
||||
|
||||
@pytest.mark.generate
|
||||
|
||||
Reference in New Issue
Block a user