Mamba: add generative tests (#31478)

This commit is contained in:
Joao Gante
2024-06-19 10:27:23 +01:00
committed by GitHub
parent 7d683f7bae
commit 83259e406d
8 changed files with 83 additions and 56 deletions

View File

@@ -503,10 +503,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# They should result in very similar logits
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
@unittest.skip("Jamba has its own special cache type") # FIXME: @gante
def test_assisted_decoding_matches_greedy_search_0_random(self):
pass
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes