From 8fc6ecba4f09a738e02d8cb08736ac924f504c08 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 12 Feb 2025 12:55:46 +0100 Subject: [PATCH] VLM: enable skipped tests (#35746) * fix cached tests * fix some tests * fix pix2struct * fix --- .../models/blip_2/modeling_blip_2.py | 6 ++ .../instructblip/modeling_instructblip.py | 6 ++ .../modeling_instructblipvideo.py | 6 ++ .../modular_instructblipvideo.py | 3 + .../models/kosmos2/modeling_kosmos2.py | 13 ++- tests/generation/test_utils.py | 45 +++++--- tests/models/aria/test_modeling_aria.py | 6 +- tests/models/mllama/test_modeling_mllama.py | 100 ++++++++++++++++++ tests/models/moshi/test_modeling_moshi.py | 12 +++ .../paligemma2/test_modeling_paligemma2.py | 39 +++++++ 10 files changed, 216 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 99d678b122..84f0356cec 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -579,6 +579,9 @@ BLIP_2_INPUTS_DOCSTRING = r""" Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): Whether to interpolate the pre-trained position encodings. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). """ BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING = r""" @@ -2094,6 +2097,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, + use_cache: Optional[bool] = None, ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]: r""" Returns: @@ -2217,6 +2221,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + use_cache=use_cache, ) logits = outputs.logits if return_dict else outputs[0] loss = None @@ -2242,6 +2247,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): output_hidden_states=output_hidden_states, return_dict=True, # toggle for easier access to loss/logits below labels=labels, + use_cache=use_cache, ) loss = outputs.loss logits = outputs.logits diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 163f978ba2..a04a27b018 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -441,6 +441,9 @@ INSTRUCTBLIP_INPUTS_DOCSTRING = r""" Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): Whether to interpolate the pre-trained position encodings. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). """ @@ -1375,6 +1378,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, + use_cache: Optional[bool] = None, ) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1485,6 +1489,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + use_cache=use_cache, ) logits = outputs.logits if return_dict else outputs[0] loss = None @@ -1510,6 +1515,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati output_hidden_states=output_hidden_states, return_dict=return_dict, labels=labels, + use_cache=use_cache, ) loss = outputs.loss if return_dict else outputs[0] logits = outputs.logits if return_dict else outputs[1] diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index e91b05bc01..18aed63920 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -1265,6 +1265,9 @@ INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r""" Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): Whether to interpolate the pre-trained position encodings. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). """ @@ -1369,6 +1372,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, + use_cache: Optional[bool] = None, ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1512,6 +1516,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + use_cache=use_cache, ) logits = outputs.logits if return_dict else outputs[0] loss = None @@ -1537,6 +1542,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel output_hidden_states=output_hidden_states, return_dict=return_dict, labels=labels, + use_cache=use_cache, ) loss = outputs.loss if return_dict else outputs[0] logits = outputs.logits if return_dict else outputs[1] diff --git a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py index 1376e85c6f..7409581358 100644 --- a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py @@ -188,6 +188,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, interpolate_pos_encoding: bool = False, + use_cache: Optional[bool] = None, ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]: r""" ```python @@ -322,6 +323,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + use_cache=use_cache, ) logits = outputs.logits if return_dict else outputs[0] loss = None @@ -347,6 +349,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera output_hidden_states=output_hidden_states, return_dict=return_dict, labels=labels, + use_cache=use_cache, ) loss = outputs.loss if return_dict else outputs[0] logits = outputs.logits if return_dict else outputs[1] diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index e6662656e7..55277cd5a1 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -1694,6 +1694,7 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin): past_key_values=None, attention_mask=None, use_cache=None, + cache_position=None, **model_kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model @@ -1704,17 +1705,21 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin): attention_mask = input_ids.new_ones(input_shape) position_ids = None + if cache_position is None: + past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device) - # cut input_ids if past_key_values is used if past_key_values is not None: position_ids = create_position_ids_from_input_ids( input_ids, padding_idx=self.config.pad_token_id, past_key_values_length=0, - )[:, -1:] + ) + + if input_ids.shape[1] != cache_position.shape[0]: + input_ids = input_ids[:, cache_position] + position_ids = position_ids[:, -input_ids.shape[1] :] - input_ids = input_ids[:, -1:] - # the image info. is already encoded into the past keys/values image_embeds = None image_embeds_position_mask = None elif image_embeds_position_mask is not None: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index dc88091f4c..a953221eb4 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -516,7 +516,7 @@ class GenerationTesterMixin: if self.has_attentions: config._attn_implementation = "eager" # can't output attentions otherwise - if not hasattr(config, "use_cache"): + if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") @@ -651,7 +651,7 @@ class GenerationTesterMixin: for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - if not hasattr(config, "use_cache"): + if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]): self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes") @@ -989,7 +989,7 @@ class GenerationTesterMixin: config, inputs_dict = self.prepare_config_and_inputs_for_generate() # NOTE: contrastive search only works with cache on at the moment. - if not hasattr(config, "use_cache"): + if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") config.is_decoder = True @@ -1018,7 +1018,7 @@ class GenerationTesterMixin: config, inputs_dict = self.prepare_config_and_inputs_for_generate() # NOTE: contrastive search only works with cache on at the moment. - if not hasattr(config, "use_cache"): + if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") config.is_decoder = True if self.has_attentions: @@ -1060,7 +1060,7 @@ class GenerationTesterMixin: config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) # NOTE: contrastive search only works with cache on at the moment. - if not hasattr(config, "use_cache"): + if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") config.is_decoder = True @@ -1179,6 +1179,10 @@ class GenerationTesterMixin: "prophetnet", "seamlessm4t", "clvp", + "mllama", # special cache sizes + "blip2", # overridden `generate()` + "instructblip", + "instructblipvideo", ] ): self.skipTest(reason="May fix in the future: need model-specific fixes") @@ -1187,7 +1191,7 @@ class GenerationTesterMixin: config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) # NOTE: assisted generation only works with cache on at the moment. - if not hasattr(config, "use_cache"): + if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") config.is_decoder = True @@ -1254,6 +1258,10 @@ class GenerationTesterMixin: "seamlessm4t", "clvp", "fuyu", + "mllama", # special cache sizes + "blip2", # overridden `generate()` + "instructblip", + "instructblipvideo", ] ): self.skipTest(reason="May fix in the future: need model-specific fixes") @@ -1262,7 +1270,7 @@ class GenerationTesterMixin: config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) # NOTE: assisted generation only works with cache on at the moment. - if not hasattr(config, "use_cache"): + if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") config.is_decoder = True @@ -1368,6 +1376,10 @@ class GenerationTesterMixin: "prophetnet", "seamlessm4t", "clvp", + "mllama", # special cache sizes + "blip2", # overridden `generate()` + "instructblip", + "instructblipvideo", ] ): self.skipTest(reason="May fix in the future: need model-specific fixes") @@ -1376,7 +1388,7 @@ class GenerationTesterMixin: config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) # NOTE: assisted generation only works with cache on at the moment. - if not hasattr(config, "use_cache"): + if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") config.is_decoder = True @@ -1570,7 +1582,7 @@ class GenerationTesterMixin: config, inputs = self.model_tester.prepare_config_and_inputs_for_common() # If it doesn't support cache, pass the test - if not hasattr(config, "use_cache"): + if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") model = model_class(config).to(torch_device) @@ -1605,7 +1617,14 @@ class GenerationTesterMixin: # Encoder-Decoder checks if config.is_encoder_decoder: - encoder_num_attention_heads = config.encoder_attention_heads + # encoder-decoder models usually don't have text config + # below is needed only for Pix2Struct which we cannot modify now due to BC + config = config.get_text_config() + encoder_num_attention_heads = ( + config.encoder_attention_heads + if hasattr(config, "encoder_attention_heads") + else config.num_attention_heads + ) encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads batch_size, seq_length = inputs["decoder_input_ids"].shape for i in range(num_hidden_layers): @@ -1804,14 +1823,14 @@ class GenerationTesterMixin: def test_generate_continue_from_past_key_values(self): # Tests that we can continue generating from past key values, returned from a previous `generate` call for model_class in self.all_generative_model_classes: - if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]): + if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]): self.skipTest(reason="Won't fix: old model with unique inputs/caches/other") if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]): self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility") config, inputs = self.model_tester.prepare_config_and_inputs_for_common() - if not hasattr(config, "use_cache"): + if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") # Let's make it always: @@ -2251,7 +2270,7 @@ class GenerationTesterMixin: config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) # NOTE: assisted generation only works with cache on at the moment. - if not hasattr(config, "use_cache"): + if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") config.use_cache = True config.is_decoder = True diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py index a59a6ba07e..9fb57eeec9 100644 --- a/tests/models/aria/test_modeling_aria.py +++ b/tests/models/aria/test_modeling_aria.py @@ -82,14 +82,14 @@ class AriaVisionText2TextModelTester: moe_intermediate_size=4, moe_num_experts=4, moe_topk=2, - num_attention_heads=20, + num_attention_heads=8, num_experts_per_tok=3, num_hidden_layers=2, - num_key_value_heads=20, + num_key_value_heads=8, rope_theta=5000000, vocab_size=99, eos_token_id=2, - head_dim=2, + head_dim=4, ), is_training=True, vision_config=Idefics3VisionConfig( diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py index 4e4c4636b7..40541fc827 100644 --- a/tests/models/mllama/test_modeling_mllama.py +++ b/tests/models/mllama/test_modeling_mllama.py @@ -29,6 +29,7 @@ from transformers import ( is_torch_available, is_vision_available, ) +from transformers.cache_utils import Cache from transformers.models.mllama.configuration_mllama import MllamaTextConfig from transformers.testing_utils import ( cleanup, @@ -378,6 +379,105 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester def test_offloaded_cache_implementation(self, cache_implementation): pass + @unittest.skip( + reason="Mllama cache type doesn't allow correct check on output `past_key_values` due to `Cache.crop()`" + ) + def test_contrastive_generate_dict_outputs_use_cache(self, assistant_type): + pass + + @unittest.skip(reason="Mllama can't do low memory due to `Cache.crop()`") + def test_contrastive_generate_low_memory(self, assistant_type): + pass + + @unittest.skip(reason="Mllama can't assisted decoding due to cache format and `Cache.crop()`") + def test_assisted_decoding_with_num_logits_to_keep(self): + pass + + @pytest.mark.generate + # overriden because mllama has special cache for self and cross attentions + def test_past_key_values_format(self): + # Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a + # standard KV cache format is important for a consistent API (and for advanced generation methods). + for model_class in self.all_generative_model_classes: + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + model = model_class(config).to(torch_device) + if "use_cache" not in inputs: + inputs["use_cache"] = True + outputs = model(**inputs) + + text_config = config.get_text_config() + num_hidden_layers = ( + getattr(text_config, "decoder_layers", None) + or getattr(text_config, "num_decoder_layers", None) + or text_config.num_hidden_layers + ) + num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads) + embed_dim = getattr(text_config, "d_model", text_config.hidden_size) + per_head_embed_dim = embed_dim // num_attention_heads + + # some models have diffent num-head for query vs key/value so we need to assign correct value + # BUT only after `per_head_embed_dim` is set + num_attention_heads = ( + text_config.num_key_value_heads + if getattr(text_config, "num_key_value_heads", None) is not None + else num_attention_heads + ) + + past_kv = outputs["past_key_values"] + self.assertEqual(len(past_kv), num_hidden_layers) + batch_size, seq_length = inputs["input_ids"].shape + for i in range(num_hidden_layers): + self.assertEqual(len(past_kv[0]), 2) # K V for the decoder = 2 + if i in self.model_tester.text_config["cross_attention_layers"]: + self.assertEqual( + past_kv[i][0].shape, + (batch_size, num_attention_heads, self.model_tester.image_length, per_head_embed_dim), + ) + self.assertEqual( + past_kv[i][1].shape, + (batch_size, num_attention_heads, self.model_tester.image_length, per_head_embed_dim), + ) + else: + self.assertEqual( + past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) + ) + self.assertEqual( + past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim) + ) + + # overriden because mllama has special cache for self and cross attentions + def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): + self.assertIsInstance(decoder_past_key_values, Cache) + self.assertListEqual( + [isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values], + [True] * len(decoder_past_key_values), + ) + + for layer_idx, layer_past_key_values in enumerate(decoder_past_key_values): + if layer_idx in self.model_tester.text_config["cross_attention_layers"]: + expected_shape = ( + batch_size, + config.num_key_value_heads + if hasattr(config, "num_key_value_heads") + else config.num_attention_heads, + self.model_tester.image_length, + config.hidden_size // config.num_attention_heads, + ) + else: + # (batch, head, cache_length, head_features) + expected_shape = ( + batch_size, + config.num_key_value_heads + if hasattr(config, "num_key_value_heads") + else config.num_attention_heads, + cache_length, + config.hidden_size // config.num_attention_heads, + ) + # check shape key, value + self.assertListEqual([layer_past_key_values[0].shape], [expected_shape]) + self.assertListEqual([layer_past_key_values[1].shape], [expected_shape]) + def test_generate_text_only_with_cache(self): """ Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py index 9eb0eaa4d4..f637fb9efa 100644 --- a/tests/models/moshi/test_modeling_moshi.py +++ b/tests/models/moshi/test_modeling_moshi.py @@ -612,6 +612,18 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): def test_contrastive_generate_low_memory(self): pass + @unittest.skip( + "Moshi either needs deafult generation config or fix for fullgraph compile because it hardcodes SlidingWindowCache in custom generation loop." + ) + def test_greedy_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip( + "Moshi either needs deafult generation config or fix for fullgraph compile because it hardcodes SlidingWindowCache in custom generation loop." + ) + def test_beam_search_generate_dict_outputs_use_cache(self): + pass + @unittest.skip("Adapting this test is costly. `test_eager_matches_sdpa_generate` tests this already.") @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa diff --git a/tests/models/paligemma2/test_modeling_paligemma2.py b/tests/models/paligemma2/test_modeling_paligemma2.py index 4a87eb329d..451d4cc17a 100644 --- a/tests/models/paligemma2/test_modeling_paligemma2.py +++ b/tests/models/paligemma2/test_modeling_paligemma2.py @@ -16,6 +16,8 @@ import unittest +from parameterized import parameterized + from transformers import ( PaliGemmaConfig, PaliGemmaForConditionalGeneration, @@ -348,3 +350,40 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe @unittest.skip("Low memory will be removed soon so no need to fix it") def test_beam_search_low_memory(self): pass + + @parameterized.expand([("random",), ("same",)]) + @unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding") + def test_assisted_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding") + def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding") + def test_assisted_decoding_sample(self): + pass + + @unittest.skip("Gemma2 has HybridCache which is not compatible with dola decoding") + def test_dola_decoding_sample(self): + pass + + @unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv") + def test_generate_continue_from_past_key_values(self): + pass + + @unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate(self): + pass + + @unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate_low_memory(self): + pass + + @unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache") + def test_generate_with_static_cache(self): + pass