Mamba: add generative tests (#31478)

This commit is contained in:
Joao Gante
2024-06-19 10:27:23 +01:00
committed by GitHub
parent 7d683f7bae
commit 83259e406d
8 changed files with 83 additions and 56 deletions

View File

@@ -503,10 +503,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# They should result in very similar logits
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
# Disabled for Jamba: the skip reason given is that Jamba has its own special cache type.
# The FIXME assigns the follow-up to @gante; the body is intentionally empty.
@unittest.skip("Jamba has its own special cache type") # FIXME: @gante
def test_assisted_decoding_matches_greedy_search_0_random(self):
pass
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes

View File

@@ -250,6 +250,8 @@ class MambaModelTester:
@require_torch
class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (MambaForCausalLM,) if is_torch_available() else ()
has_attentions = False # Mamba does not support attentions
fx_compatible = False # FIXME let's try to support this @ArthurZucker
test_torchscript = False # FIXME let's try to support this @ArthurZucker
test_missing_keys = False
@@ -292,10 +294,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# Runs the common checks exposed by `self.config_tester`.
def test_config(self):
self.config_tester.run_common_tests()
# Disabled for Mamba: the skip reason given is that the model has no attention.
# The body is intentionally empty.
@unittest.skip("No attention in mamba")
def test_retain_grad_hidden_states_attentions(self):
pass
@require_torch_multi_gpu
def test_multi_gpu_data_parallel_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -364,14 +362,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# check if it's a ones like
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
# Disabled for Mamba: the skip reason given is that the model does not use attention.
@unittest.skip("Mamba does not use attention")
def test_attention_outputs(self):
r"""
Overrides the common `test_attention_outputs` test. The original note here states that the attention outputs of
Mamba differ from those of other models, having shape `batch_size, seq_len, hidden_size`; the test is skipped
outright, with the reason given as Mamba not using attention.
"""
pass
@slow
def test_model_from_pretrained(self):
model = MambaModel.from_pretrained("hf-internal-testing/mamba-130m")