VLM: enable skipped tests (#35746)
* fix cached tests * fix some tests * fix pix2struct * fix
This commit is contained in:
committed by
GitHub
parent
d6897b46bd
commit
8fc6ecba4f
@@ -579,6 +579,9 @@ BLIP_2_INPUTS_DOCSTRING = r"""
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
"""
|
||||
|
||||
BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING = r"""
|
||||
@@ -2094,6 +2097,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
use_cache: Optional[bool] = None,
|
||||
) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -2217,6 +2221,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
logits = outputs.logits if return_dict else outputs[0]
|
||||
loss = None
|
||||
@@ -2242,6 +2247,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True, # toggle for easier access to loss/logits below
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
loss = outputs.loss
|
||||
logits = outputs.logits
|
||||
|
||||
@@ -441,6 +441,9 @@ INSTRUCTBLIP_INPUTS_DOCSTRING = r"""
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
"""
|
||||
|
||||
|
||||
@@ -1375,6 +1378,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
use_cache: Optional[bool] = None,
|
||||
) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
@@ -1485,6 +1489,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
logits = outputs.logits if return_dict else outputs[0]
|
||||
loss = None
|
||||
@@ -1510,6 +1515,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
loss = outputs.loss if return_dict else outputs[0]
|
||||
logits = outputs.logits if return_dict else outputs[1]
|
||||
|
||||
@@ -1265,6 +1265,9 @@ INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
"""
|
||||
|
||||
|
||||
@@ -1369,6 +1372,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
use_cache: Optional[bool] = None,
|
||||
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
@@ -1512,6 +1516,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
logits = outputs.logits if return_dict else outputs[0]
|
||||
loss = None
|
||||
@@ -1537,6 +1542,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
loss = outputs.loss if return_dict else outputs[0]
|
||||
logits = outputs.logits if return_dict else outputs[1]
|
||||
|
||||
@@ -188,6 +188,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
use_cache: Optional[bool] = None,
|
||||
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
|
||||
r"""
|
||||
```python
|
||||
@@ -322,6 +323,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
logits = outputs.logits if return_dict else outputs[0]
|
||||
loss = None
|
||||
@@ -347,6 +349,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
loss = outputs.loss if return_dict else outputs[0]
|
||||
logits = outputs.logits if return_dict else outputs[1]
|
||||
|
||||
@@ -1694,6 +1694,7 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
use_cache=None,
|
||||
cache_position=None,
|
||||
**model_kwargs,
|
||||
):
|
||||
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
||||
@@ -1704,17 +1705,21 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
|
||||
attention_mask = input_ids.new_ones(input_shape)
|
||||
|
||||
position_ids = None
|
||||
if cache_position is None:
|
||||
past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
||||
|
||||
# cut input_ids if past_key_values is used
|
||||
if past_key_values is not None:
|
||||
position_ids = create_position_ids_from_input_ids(
|
||||
input_ids,
|
||||
padding_idx=self.config.pad_token_id,
|
||||
past_key_values_length=0,
|
||||
)[:, -1:]
|
||||
)
|
||||
|
||||
if input_ids.shape[1] != cache_position.shape[0]:
|
||||
input_ids = input_ids[:, cache_position]
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
input_ids = input_ids[:, -1:]
|
||||
# the image info. is already encoded into the past keys/values
|
||||
image_embeds = None
|
||||
image_embeds_position_mask = None
|
||||
elif image_embeds_position_mask is not None:
|
||||
|
||||
@@ -516,7 +516,7 @@ class GenerationTesterMixin:
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
if not hasattr(config.get_text_config(), "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
||||
@@ -651,7 +651,7 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
if not hasattr(config.get_text_config(), "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
||||
@@ -989,7 +989,7 @@ class GenerationTesterMixin:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
if not hasattr(config.get_text_config(), "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
config.is_decoder = True
|
||||
|
||||
@@ -1018,7 +1018,7 @@ class GenerationTesterMixin:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
if not hasattr(config.get_text_config(), "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
config.is_decoder = True
|
||||
if self.has_attentions:
|
||||
@@ -1060,7 +1060,7 @@ class GenerationTesterMixin:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||
|
||||
# NOTE: contrastive search only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
if not hasattr(config.get_text_config(), "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
config.is_decoder = True
|
||||
@@ -1179,6 +1179,10 @@ class GenerationTesterMixin:
|
||||
"prophetnet",
|
||||
"seamlessm4t",
|
||||
"clvp",
|
||||
"mllama", # special cache sizes
|
||||
"blip2", # overridden `generate()`
|
||||
"instructblip",
|
||||
"instructblipvideo",
|
||||
]
|
||||
):
|
||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||
@@ -1187,7 +1191,7 @@ class GenerationTesterMixin:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
if not hasattr(config.get_text_config(), "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
config.is_decoder = True
|
||||
@@ -1254,6 +1258,10 @@ class GenerationTesterMixin:
|
||||
"seamlessm4t",
|
||||
"clvp",
|
||||
"fuyu",
|
||||
"mllama", # special cache sizes
|
||||
"blip2", # overridden `generate()`
|
||||
"instructblip",
|
||||
"instructblipvideo",
|
||||
]
|
||||
):
|
||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||
@@ -1262,7 +1270,7 @@ class GenerationTesterMixin:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
if not hasattr(config.get_text_config(), "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
config.is_decoder = True
|
||||
@@ -1368,6 +1376,10 @@ class GenerationTesterMixin:
|
||||
"prophetnet",
|
||||
"seamlessm4t",
|
||||
"clvp",
|
||||
"mllama", # special cache sizes
|
||||
"blip2", # overridden `generate()`
|
||||
"instructblip",
|
||||
"instructblipvideo",
|
||||
]
|
||||
):
|
||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||
@@ -1376,7 +1388,7 @@ class GenerationTesterMixin:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
if not hasattr(config.get_text_config(), "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
config.is_decoder = True
|
||||
@@ -1570,7 +1582,7 @@ class GenerationTesterMixin:
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# If it doesn't support cache, pass the test
|
||||
if not hasattr(config, "use_cache"):
|
||||
if not hasattr(config.get_text_config(), "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
@@ -1605,7 +1617,14 @@ class GenerationTesterMixin:
|
||||
|
||||
# Encoder-Decoder checks
|
||||
if config.is_encoder_decoder:
|
||||
encoder_num_attention_heads = config.encoder_attention_heads
|
||||
# encoder-decoder models usually don't have text config
|
||||
# below is needed only for Pix2Struct which we cannot modify now due to BC
|
||||
config = config.get_text_config()
|
||||
encoder_num_attention_heads = (
|
||||
config.encoder_attention_heads
|
||||
if hasattr(config, "encoder_attention_heads")
|
||||
else config.num_attention_heads
|
||||
)
|
||||
encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
|
||||
batch_size, seq_length = inputs["decoder_input_ids"].shape
|
||||
for i in range(num_hidden_layers):
|
||||
@@ -1804,14 +1823,14 @@ class GenerationTesterMixin:
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
# Tests that we can continue generating from past key values, returned from a previous `generate` call
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]):
|
||||
self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
|
||||
self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
|
||||
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
if not hasattr(config.get_text_config(), "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
# Let's make it always:
|
||||
@@ -2251,7 +2270,7 @@ class GenerationTesterMixin:
|
||||
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config, "use_cache"):
|
||||
if not hasattr(config.get_text_config(), "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
@@ -82,14 +82,14 @@ class AriaVisionText2TextModelTester:
|
||||
moe_intermediate_size=4,
|
||||
moe_num_experts=4,
|
||||
moe_topk=2,
|
||||
num_attention_heads=20,
|
||||
num_attention_heads=8,
|
||||
num_experts_per_tok=3,
|
||||
num_hidden_layers=2,
|
||||
num_key_value_heads=20,
|
||||
num_key_value_heads=8,
|
||||
rope_theta=5000000,
|
||||
vocab_size=99,
|
||||
eos_token_id=2,
|
||||
head_dim=2,
|
||||
head_dim=4,
|
||||
),
|
||||
is_training=True,
|
||||
vision_config=Idefics3VisionConfig(
|
||||
|
||||
@@ -29,6 +29,7 @@ from transformers import (
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.models.mllama.configuration_mllama import MllamaTextConfig
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
@@ -378,6 +379,105 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Mllama cache type doesn't allow correct check on output `past_key_values` due to `Cache.crop()`"
|
||||
)
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self, assistant_type):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Mllama can't do low memory due to `Cache.crop()`")
|
||||
def test_contrastive_generate_low_memory(self, assistant_type):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Mllama can't assisted decoding due to cache format and `Cache.crop()`")
|
||||
def test_assisted_decoding_with_num_logits_to_keep(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
# overriden because mllama has special cache for self and cross attentions
|
||||
def test_past_key_values_format(self):
|
||||
# Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
|
||||
# standard KV cache format is important for a consistent API (and for advanced generation methods).
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
if "use_cache" not in inputs:
|
||||
inputs["use_cache"] = True
|
||||
outputs = model(**inputs)
|
||||
|
||||
text_config = config.get_text_config()
|
||||
num_hidden_layers = (
|
||||
getattr(text_config, "decoder_layers", None)
|
||||
or getattr(text_config, "num_decoder_layers", None)
|
||||
or text_config.num_hidden_layers
|
||||
)
|
||||
num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
|
||||
embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
|
||||
per_head_embed_dim = embed_dim // num_attention_heads
|
||||
|
||||
# some models have diffent num-head for query vs key/value so we need to assign correct value
|
||||
# BUT only after `per_head_embed_dim` is set
|
||||
num_attention_heads = (
|
||||
text_config.num_key_value_heads
|
||||
if getattr(text_config, "num_key_value_heads", None) is not None
|
||||
else num_attention_heads
|
||||
)
|
||||
|
||||
past_kv = outputs["past_key_values"]
|
||||
self.assertEqual(len(past_kv), num_hidden_layers)
|
||||
batch_size, seq_length = inputs["input_ids"].shape
|
||||
for i in range(num_hidden_layers):
|
||||
self.assertEqual(len(past_kv[0]), 2) # K V for the decoder = 2
|
||||
if i in self.model_tester.text_config["cross_attention_layers"]:
|
||||
self.assertEqual(
|
||||
past_kv[i][0].shape,
|
||||
(batch_size, num_attention_heads, self.model_tester.image_length, per_head_embed_dim),
|
||||
)
|
||||
self.assertEqual(
|
||||
past_kv[i][1].shape,
|
||||
(batch_size, num_attention_heads, self.model_tester.image_length, per_head_embed_dim),
|
||||
)
|
||||
else:
|
||||
self.assertEqual(
|
||||
past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
||||
)
|
||||
self.assertEqual(
|
||||
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
||||
)
|
||||
|
||||
# overriden because mllama has special cache for self and cross attentions
|
||||
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
|
||||
self.assertIsInstance(decoder_past_key_values, Cache)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values],
|
||||
[True] * len(decoder_past_key_values),
|
||||
)
|
||||
|
||||
for layer_idx, layer_past_key_values in enumerate(decoder_past_key_values):
|
||||
if layer_idx in self.model_tester.text_config["cross_attention_layers"]:
|
||||
expected_shape = (
|
||||
batch_size,
|
||||
config.num_key_value_heads
|
||||
if hasattr(config, "num_key_value_heads")
|
||||
else config.num_attention_heads,
|
||||
self.model_tester.image_length,
|
||||
config.hidden_size // config.num_attention_heads,
|
||||
)
|
||||
else:
|
||||
# (batch, head, cache_length, head_features)
|
||||
expected_shape = (
|
||||
batch_size,
|
||||
config.num_key_value_heads
|
||||
if hasattr(config, "num_key_value_heads")
|
||||
else config.num_attention_heads,
|
||||
cache_length,
|
||||
config.hidden_size // config.num_attention_heads,
|
||||
)
|
||||
# check shape key, value
|
||||
self.assertListEqual([layer_past_key_values[0].shape], [expected_shape])
|
||||
self.assertListEqual([layer_past_key_values[1].shape], [expected_shape])
|
||||
|
||||
def test_generate_text_only_with_cache(self):
|
||||
"""
|
||||
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
|
||||
|
||||
@@ -612,6 +612,18 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Moshi either needs deafult generation config or fix for fullgraph compile because it hardcodes SlidingWindowCache in custom generation loop."
|
||||
)
|
||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Moshi either needs deafult generation config or fix for fullgraph compile because it hardcodes SlidingWindowCache in custom generation loop."
|
||||
)
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Adapting this test is costly. `test_eager_matches_sdpa_generate` tests this already.")
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@require_torch_sdpa
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
PaliGemmaConfig,
|
||||
PaliGemmaForConditionalGeneration,
|
||||
@@ -348,3 +350,40 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
||||
@unittest.skip("Low memory will be removed soon so no need to fix it")
|
||||
def test_beam_search_low_memory(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache which is not compatible with dola decoding")
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv")
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache")
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user