@@ -291,10 +291,6 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Mixtral uses GQA on all models so the KV cache is a non standard format")
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
|
||||
Reference in New Issue
Block a user