committed by
GitHub
parent
295a90cb40
commit
5ee52ae0bc
@@ -138,14 +138,6 @@ class MllamaForCausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
||||
def test_eager_matches_sdpa_generate(self):
|
||||
super().test_eager_matches_sdpa_generate()
|
||||
|
||||
@unittest.skip(reason="The outputs don't match, no idea why")
|
||||
def test_beam_search_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Quanto test is borken")
|
||||
def test_generate_with_quant_cache(self):
|
||||
pass
|
||||
|
||||
|
||||
class MllamaVisionText2TextModelTester:
|
||||
def __init__(
|
||||
@@ -208,6 +200,7 @@ class MllamaVisionText2TextModelTester:
|
||||
self.image_size = 224
|
||||
self.max_num_images = 1
|
||||
self.max_image_tiles = 4
|
||||
self.image_length = 904
|
||||
|
||||
def get_config(self):
|
||||
return MllamaConfig(
|
||||
@@ -329,6 +322,43 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
||||
):
|
||||
# Mllama has cross attention layers and those have a different shape than normal attention layers
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
||||
)
|
||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
||||
|
||||
cross_attention_layers = self.model_tester.text_config["cross_attention_layers"]
|
||||
|
||||
for idx, iter_attentions in enumerate(attentions):
|
||||
tgt_len = min_length + idx if not use_cache else 1
|
||||
src_len = min_length + idx
|
||||
|
||||
expected_shape = (
|
||||
batch_size * num_beam_groups,
|
||||
config.num_attention_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
)
|
||||
|
||||
expected_shape_cross = (
|
||||
batch_size * num_beam_groups,
|
||||
config.num_attention_heads,
|
||||
tgt_len,
|
||||
self.model_tester.image_length,
|
||||
)
|
||||
|
||||
expected_shapes = [
|
||||
expected_shape if layer_idx not in cross_attention_layers else expected_shape_cross
|
||||
for layer_idx in range(len(iter_attentions))
|
||||
]
|
||||
|
||||
self.assertListEqual([layer_attention.shape for layer_attention in iter_attentions], expected_shapes)
|
||||
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
@is_flaky()
|
||||
@@ -342,94 +372,14 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
# A workaround to override parametrized test with flaky decorator
|
||||
super().test_eager_matches_sdpa_inference_1_bfloat16()
|
||||
|
||||
@unittest.skip(reason="Static cache not supported")
|
||||
def test_static_cache_matches_dynamic(self):
|
||||
# TypeError: list indices must be integers or slices, not tuple
|
||||
# TODO: @raushan, please look into this for new cache format
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Mllama has dynamic control flow which is not yet supported by compile")
|
||||
def test_generate_compile_fullgraph(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="The outputs don't match, no idea why")
|
||||
def test_beam_search_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Mllama is not yet supported by compile")
|
||||
@unittest.skip("For some unknown reasons the tests fails in CrossAttention layer when doing torch.sdpa(). ")
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
# TODO: look into this, AttributeError("'tensor' object has no attribute '__pow__'")
|
||||
# relevant issue: https://github.com/pytorch/pytorch/issues/133166
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="The test itself is broken") # TODO @zucchini-nlp
|
||||
def test_generate_with_quant_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="AssertionError: Items in the second set but not the first: might be a setting issue")
|
||||
def test_model_parallelism(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_compile_cuda_graph_time(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_torch_compile_fullgraph(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Device side assert triggered")
|
||||
def test_assisted_decoding_with_num_logits_to_keep(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_beam_sample_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_generate_methods_with_num_logits_to_keep(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_greedy_generate_dict_outputs(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_group_beam_search_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_model_parallel_beam_search(self):
|
||||
pass
|
||||
|
||||
@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
|
||||
def test_new_cache_format_0(self):
|
||||
super().test_new_cache_format_0()
|
||||
|
||||
@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
|
||||
def test_new_cache_format_1(self):
|
||||
super().test_new_cache_format_1()
|
||||
|
||||
@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
|
||||
def test_new_cache_format_2(self):
|
||||
super().test_new_cache_format_2()
|
||||
|
||||
@unittest.skip(reason="Failing test, need to fix")
|
||||
def test_sample_generate_dict_output(self):
|
||||
pass
|
||||
|
||||
def test_generate_text_only_with_cache(self):
|
||||
"""
|
||||
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
|
||||
|
||||
Reference in New Issue
Block a user