From 24f771a043871a109d8a969bf92730746e085f3b Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 11 Jul 2025 16:30:56 +0200 Subject: [PATCH] fix failing `test_sdpa_can_dispatch_on_flash` (#39259) * fix * fix * fix --------- Co-authored-by: ydshieh --- src/transformers/utils/generic.py | 15 +++++++++++---- tests/models/t5gemma/test_modeling_t5gemma.py | 6 +++--- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 8b6afb72ed..9692f705f2 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -993,16 +993,23 @@ def check_model_inputs(func): @wraps(func) def wrapper(self, *args, **kwargs): - use_cache = kwargs.get("use_cache", getattr(self.config, "use_cache", False)) - return_dict = kwargs.pop("return_dict", getattr(self.config, "return_dict", True)) - all_args = kwargs.copy() + use_cache = kwargs.get("use_cache", None) + if use_cache is None: + use_cache = getattr(self.config, "use_cache", False) + + return_dict = kwargs.pop("return_dict", None) + if return_dict is None: + return_dict = getattr(self.config, "return_dict", True) if getattr(self, "gradient_checkpointing", False) and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) - kwargs["use_cache"] = False + use_cache = False + kwargs["use_cache"] = use_cache + + all_args = kwargs.copy() if "kwargs" in all_args: for k, v in all_args["kwargs"].items(): all_args[k] = v diff --git a/tests/models/t5gemma/test_modeling_t5gemma.py b/tests/models/t5gemma/test_modeling_t5gemma.py index a9835aee71..0020c5c78e 100644 --- a/tests/models/t5gemma/test_modeling_t5gemma.py +++ b/tests/models/t5gemma/test_modeling_t5gemma.py @@ -311,7 +311,7 @@ class T5GemmaModelTester: decoder_attention_mask=decoder_attention_mask, labels=lm_labels, ) - self.parent.assertEqual(len(outputs), 4) + self.parent.assertEqual(len(outputs), 5) self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertEqual(outputs["loss"].size(), ()) @@ -1067,7 +1067,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi for i in range(num_decoder_layers): if is_legacy_cache: - self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple + self.assertEqual(len(past_kv[0]), 5) # legacy check: confirm number of elements in tuple # Self attention self_attention_layer_key_cache = ( @@ -1687,7 +1687,7 @@ class TestAsymmetricT5Gemma(unittest.TestCase): labels=lm_labels, ) # outputs = model(*inputs) - assert len(outputs) == 4 + assert len(outputs) == 5 assert outputs["logits"].size() == (tester.batch_size, tester.seq_length, tester.vocab_size) assert outputs["loss"].size() == () return model.model