fix failing test_sdpa_can_dispatch_on_flash (#39259)

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-07-11 16:30:56 +02:00
committed by GitHub
parent ee74397d20
commit 24f771a043
2 changed files with 14 additions and 7 deletions

View File

@@ -311,7 +311,7 @@ class T5GemmaModelTester:
decoder_attention_mask=decoder_attention_mask,
labels=lm_labels,
)
self.parent.assertEqual(len(outputs), 4)
self.parent.assertEqual(len(outputs), 5)
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertEqual(outputs["loss"].size(), ())
@@ -1067,7 +1067,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
for i in range(num_decoder_layers):
if is_legacy_cache:
self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple
self.assertEqual(len(past_kv[0]), 5) # legacy check: confirm number of elements in tuple
# Self attention
self_attention_layer_key_cache = (
@@ -1687,7 +1687,7 @@ class TestAsymmetricT5Gemma(unittest.TestCase):
labels=lm_labels,
)
# outputs = model(*inputs)
assert len(outputs) == 4
assert len(outputs) == 5
assert outputs["logits"].size() == (tester.batch_size, tester.seq_length, tester.vocab_size)
assert outputs["loss"].size() == ()
return model.model