Mamba: add generative tests (#31478)
This commit is contained in:
@@ -503,10 +503,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
# They should result in very similar logits
|
||||
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
|
||||
|
||||
@unittest.skip("Jamba has its own special cache type") # FIXME: @gante
|
||||
def test_assisted_decoding_matches_greedy_search_0_random(self):
|
||||
pass
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
|
||||
@@ -250,6 +250,8 @@ class MambaModelTester:
|
||||
@require_torch
|
||||
class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (MambaForCausalLM,) if is_torch_available() else ()
|
||||
has_attentions = False # Mamba does not support attentions
|
||||
fx_compatible = False # FIXME let's try to support this @ArthurZucker
|
||||
test_torchscript = False # FIXME let's try to support this @ArthurZucker
|
||||
test_missing_keys = False
|
||||
@@ -292,10 +294,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@unittest.skip("No attention in mamba")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -364,14 +362,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
# check if it's a ones like
|
||||
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
|
||||
|
||||
@unittest.skip("Mamba does not use attention")
|
||||
def test_attention_outputs(self):
|
||||
r"""
|
||||
Overriding the test_attention_outputs test as the attention outputs of Mamba are different from other models
|
||||
it has a shape `batch_size, seq_len, hidden_size`.
|
||||
"""
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model = MambaModel.from_pretrained("hf-internal-testing/mamba-130m")
|
||||
|
||||
Reference in New Issue
Block a user