fix failing test_sdpa_can_dispatch_on_flash (#39259)

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-07-11 16:30:56 +02:00
committed by GitHub
parent ee74397d20
commit 24f771a043
2 changed files with 14 additions and 7 deletions

View File

@@ -993,16 +993,23 @@ def check_model_inputs(func):
@wraps(func) @wraps(func)
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
use_cache = kwargs.get("use_cache", getattr(self.config, "use_cache", False)) use_cache = kwargs.get("use_cache", None)
return_dict = kwargs.pop("return_dict", getattr(self.config, "return_dict", True)) if use_cache is None:
all_args = kwargs.copy() use_cache = getattr(self.config, "use_cache", False)
return_dict = kwargs.pop("return_dict", None)
if return_dict is None:
return_dict = getattr(self.config, "return_dict", True)
if getattr(self, "gradient_checkpointing", False) and self.training and use_cache: if getattr(self, "gradient_checkpointing", False) and self.training and use_cache:
logger.warning_once( logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
) )
kwargs["use_cache"] = False use_cache = False
kwargs["use_cache"] = use_cache
all_args = kwargs.copy()
if "kwargs" in all_args: if "kwargs" in all_args:
for k, v in all_args["kwargs"].items(): for k, v in all_args["kwargs"].items():
all_args[k] = v all_args[k] = v

View File

@@ -311,7 +311,7 @@ class T5GemmaModelTester:
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
labels=lm_labels, labels=lm_labels,
) )
self.parent.assertEqual(len(outputs), 4) self.parent.assertEqual(len(outputs), 5)
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertEqual(outputs["loss"].size(), ()) self.parent.assertEqual(outputs["loss"].size(), ())
@@ -1067,7 +1067,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
for i in range(num_decoder_layers): for i in range(num_decoder_layers):
if is_legacy_cache: if is_legacy_cache:
self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple self.assertEqual(len(past_kv[0]), 5) # legacy check: confirm number of elements in tuple
# Self attention # Self attention
self_attention_layer_key_cache = ( self_attention_layer_key_cache = (
@@ -1687,7 +1687,7 @@ class TestAsymmetricT5Gemma(unittest.TestCase):
labels=lm_labels, labels=lm_labels,
) )
# outputs = model(*inputs) # outputs = model(*inputs)
assert len(outputs) == 4 assert len(outputs) == 5
assert outputs["logits"].size() == (tester.batch_size, tester.seq_length, tester.vocab_size) assert outputs["logits"].size() == (tester.batch_size, tester.seq_length, tester.vocab_size)
assert outputs["loss"].size() == () assert outputs["loss"].size() == ()
return model.model return model.model