make test_eager_matches_sdpa_inference less flaky (#34512)

* try

* try

* try

* try

* try

* try

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2024-10-31 18:34:00 +01:00
committed by GitHub
parent 294c170ff9
commit 114dd812dd
4 changed files with 45 additions and 54 deletions

View File

@@ -1263,6 +1263,9 @@ class GenerationTesterMixin:
if model.get_output_embeddings() is None:
self.skipTest("DoLa is not supported for models that don't have output embeddings")
+logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
# Sets dola generation arguments such that:
# a) no EOS is generated, to ensure generation doesn't break early
# b) there are at least two forward passes in the main model, to ensure the input preparation of
@@ -1280,7 +1283,7 @@ class GenerationTesterMixin:
"use_cache": getattr(config, "use_cache", False), # Some models don't support the cache
"dola_layers": "low",
}
-output_dola = model.generate(**generation_kwargs, **inputs_dict)
+output_dola = model.generate(**generation_kwargs, **logits_processor_kwargs, **inputs_dict)
self._check_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False))
@pytest.mark.generate