fix failing test_sdpa_can_dispatch_on_flash (#39259)
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -311,7 +311,7 @@ class T5GemmaModelTester:
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
labels=lm_labels,
|
||||
)
|
||||
self.parent.assertEqual(len(outputs), 4)
|
||||
self.parent.assertEqual(len(outputs), 5)
|
||||
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertEqual(outputs["loss"].size(), ())
|
||||
|
||||
@@ -1067,7 +1067,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
|
||||
for i in range(num_decoder_layers):
|
||||
if is_legacy_cache:
|
||||
self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple
|
||||
self.assertEqual(len(past_kv[0]), 5) # legacy check: confirm number of elements in tuple
|
||||
|
||||
# Self attention
|
||||
self_attention_layer_key_cache = (
|
||||
@@ -1687,7 +1687,7 @@ class TestAsymmetricT5Gemma(unittest.TestCase):
|
||||
labels=lm_labels,
|
||||
)
|
||||
# outputs = model(*inputs)
|
||||
assert len(outputs) == 4
|
||||
assert len(outputs) == 5
|
||||
assert outputs["logits"].size() == (tester.batch_size, tester.seq_length, tester.vocab_size)
|
||||
assert outputs["loss"].size() == ()
|
||||
return model.model
|
||||
|
||||
Reference in New Issue
Block a user