[test] update test_past_key_values_format (#37614)

allow custom shapes
This commit is contained in:
Joao Gante
2025-04-22 11:07:34 +01:00
committed by GitHub
parent 1cd110c6cb
commit 362fa37da2
19 changed files with 134 additions and 166 deletions

View File

@@ -281,10 +281,6 @@ class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
@unittest.skip(reason="Starcoder2 uses GQA on all models so the KV cache is a non standard format")
def test_past_key_values_format(self):
pass
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test