VLM: enable skipped tests (#35746)
* fix cached tests * fix some tests * fix pix2struct * fix
This commit is contained in:
committed by
GitHub
parent
d6897b46bd
commit
8fc6ecba4f
@@ -579,6 +579,9 @@ BLIP_2_INPUTS_DOCSTRING = r"""
|
|||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to interpolate the pre-trained position encodings.
|
Whether to interpolate the pre-trained position encodings.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||||
|
`past_key_values`).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING = r"""
|
BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING = r"""
|
||||||
@@ -2094,6 +2097,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
|||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
interpolate_pos_encoding: bool = False,
|
interpolate_pos_encoding: bool = False,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
|
) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
@@ -2217,6 +2221,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
logits = outputs.logits if return_dict else outputs[0]
|
logits = outputs.logits if return_dict else outputs[0]
|
||||||
loss = None
|
loss = None
|
||||||
@@ -2242,6 +2247,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=True, # toggle for easier access to loss/logits below
|
return_dict=True, # toggle for easier access to loss/logits below
|
||||||
labels=labels,
|
labels=labels,
|
||||||
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
loss = outputs.loss
|
loss = outputs.loss
|
||||||
logits = outputs.logits
|
logits = outputs.logits
|
||||||
|
|||||||
@@ -441,6 +441,9 @@ INSTRUCTBLIP_INPUTS_DOCSTRING = r"""
|
|||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to interpolate the pre-trained position encodings.
|
Whether to interpolate the pre-trained position encodings.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||||
|
`past_key_values`).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -1375,6 +1378,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
|||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
interpolate_pos_encoding: bool = False,
|
interpolate_pos_encoding: bool = False,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]:
|
) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
@@ -1485,6 +1489,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
logits = outputs.logits if return_dict else outputs[0]
|
logits = outputs.logits if return_dict else outputs[0]
|
||||||
loss = None
|
loss = None
|
||||||
@@ -1510,6 +1515,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
loss = outputs.loss if return_dict else outputs[0]
|
loss = outputs.loss if return_dict else outputs[0]
|
||||||
logits = outputs.logits if return_dict else outputs[1]
|
logits = outputs.logits if return_dict else outputs[1]
|
||||||
|
|||||||
@@ -1265,6 +1265,9 @@ INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = r"""
|
|||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to interpolate the pre-trained position encodings.
|
Whether to interpolate the pre-trained position encodings.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||||
|
`past_key_values`).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -1369,6 +1372,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
|||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
interpolate_pos_encoding: bool = False,
|
interpolate_pos_encoding: bool = False,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
|
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||||
@@ -1512,6 +1516,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
logits = outputs.logits if return_dict else outputs[0]
|
logits = outputs.logits if return_dict else outputs[0]
|
||||||
loss = None
|
loss = None
|
||||||
@@ -1537,6 +1542,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
loss = outputs.loss if return_dict else outputs[0]
|
loss = outputs.loss if return_dict else outputs[0]
|
||||||
logits = outputs.logits if return_dict else outputs[1]
|
logits = outputs.logits if return_dict else outputs[1]
|
||||||
|
|||||||
@@ -188,6 +188,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
|||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
interpolate_pos_encoding: bool = False,
|
interpolate_pos_encoding: bool = False,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
|
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
```python
|
```python
|
||||||
@@ -322,6 +323,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
logits = outputs.logits if return_dict else outputs[0]
|
logits = outputs.logits if return_dict else outputs[0]
|
||||||
loss = None
|
loss = None
|
||||||
@@ -347,6 +349,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
loss = outputs.loss if return_dict else outputs[0]
|
loss = outputs.loss if return_dict else outputs[0]
|
||||||
logits = outputs.logits if return_dict else outputs[1]
|
logits = outputs.logits if return_dict else outputs[1]
|
||||||
|
|||||||
@@ -1694,6 +1694,7 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
|
|||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
|
cache_position=None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
):
|
):
|
||||||
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
||||||
@@ -1704,17 +1705,21 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
|
|||||||
attention_mask = input_ids.new_ones(input_shape)
|
attention_mask = input_ids.new_ones(input_shape)
|
||||||
|
|
||||||
position_ids = None
|
position_ids = None
|
||||||
|
if cache_position is None:
|
||||||
|
past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
||||||
|
|
||||||
# cut input_ids if past_key_values is used
|
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
position_ids = create_position_ids_from_input_ids(
|
position_ids = create_position_ids_from_input_ids(
|
||||||
input_ids,
|
input_ids,
|
||||||
padding_idx=self.config.pad_token_id,
|
padding_idx=self.config.pad_token_id,
|
||||||
past_key_values_length=0,
|
past_key_values_length=0,
|
||||||
)[:, -1:]
|
)
|
||||||
|
|
||||||
|
if input_ids.shape[1] != cache_position.shape[0]:
|
||||||
|
input_ids = input_ids[:, cache_position]
|
||||||
|
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
input_ids = input_ids[:, -1:]
|
|
||||||
# the image info. is already encoded into the past keys/values
|
|
||||||
image_embeds = None
|
image_embeds = None
|
||||||
image_embeds_position_mask = None
|
image_embeds_position_mask = None
|
||||||
elif image_embeds_position_mask is not None:
|
elif image_embeds_position_mask is not None:
|
||||||
|
|||||||
@@ -516,7 +516,7 @@ class GenerationTesterMixin:
|
|||||||
if self.has_attentions:
|
if self.has_attentions:
|
||||||
config._attn_implementation = "eager" # can't output attentions otherwise
|
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||||
|
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config.get_text_config(), "use_cache"):
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||||
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||||
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
||||||
@@ -651,7 +651,7 @@ class GenerationTesterMixin:
|
|||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
|
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config.get_text_config(), "use_cache"):
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||||
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||||
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
||||||
@@ -989,7 +989,7 @@ class GenerationTesterMixin:
|
|||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
|
|
||||||
# NOTE: contrastive search only works with cache on at the moment.
|
# NOTE: contrastive search only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config.get_text_config(), "use_cache"):
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
|
|
||||||
@@ -1018,7 +1018,7 @@ class GenerationTesterMixin:
|
|||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
|
|
||||||
# NOTE: contrastive search only works with cache on at the moment.
|
# NOTE: contrastive search only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config.get_text_config(), "use_cache"):
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
if self.has_attentions:
|
if self.has_attentions:
|
||||||
@@ -1060,7 +1060,7 @@ class GenerationTesterMixin:
|
|||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||||
|
|
||||||
# NOTE: contrastive search only works with cache on at the moment.
|
# NOTE: contrastive search only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config.get_text_config(), "use_cache"):
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||||
|
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
@@ -1179,6 +1179,10 @@ class GenerationTesterMixin:
|
|||||||
"prophetnet",
|
"prophetnet",
|
||||||
"seamlessm4t",
|
"seamlessm4t",
|
||||||
"clvp",
|
"clvp",
|
||||||
|
"mllama", # special cache sizes
|
||||||
|
"blip2", # overridden `generate()`
|
||||||
|
"instructblip",
|
||||||
|
"instructblipvideo",
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||||
@@ -1187,7 +1191,7 @@ class GenerationTesterMixin:
|
|||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||||
|
|
||||||
# NOTE: assisted generation only works with cache on at the moment.
|
# NOTE: assisted generation only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config.get_text_config(), "use_cache"):
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||||
|
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
@@ -1254,6 +1258,10 @@ class GenerationTesterMixin:
|
|||||||
"seamlessm4t",
|
"seamlessm4t",
|
||||||
"clvp",
|
"clvp",
|
||||||
"fuyu",
|
"fuyu",
|
||||||
|
"mllama", # special cache sizes
|
||||||
|
"blip2", # overridden `generate()`
|
||||||
|
"instructblip",
|
||||||
|
"instructblipvideo",
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||||
@@ -1262,7 +1270,7 @@ class GenerationTesterMixin:
|
|||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||||
|
|
||||||
# NOTE: assisted generation only works with cache on at the moment.
|
# NOTE: assisted generation only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config.get_text_config(), "use_cache"):
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||||
|
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
@@ -1368,6 +1376,10 @@ class GenerationTesterMixin:
|
|||||||
"prophetnet",
|
"prophetnet",
|
||||||
"seamlessm4t",
|
"seamlessm4t",
|
||||||
"clvp",
|
"clvp",
|
||||||
|
"mllama", # special cache sizes
|
||||||
|
"blip2", # overridden `generate()`
|
||||||
|
"instructblip",
|
||||||
|
"instructblipvideo",
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||||
@@ -1376,7 +1388,7 @@ class GenerationTesterMixin:
|
|||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||||
|
|
||||||
# NOTE: assisted generation only works with cache on at the moment.
|
# NOTE: assisted generation only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config.get_text_config(), "use_cache"):
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||||
|
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
@@ -1570,7 +1582,7 @@ class GenerationTesterMixin:
|
|||||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
# If it doesn't support cache, pass the test
|
# If it doesn't support cache, pass the test
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config.get_text_config(), "use_cache"):
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||||
|
|
||||||
model = model_class(config).to(torch_device)
|
model = model_class(config).to(torch_device)
|
||||||
@@ -1605,7 +1617,14 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
# Encoder-Decoder checks
|
# Encoder-Decoder checks
|
||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
encoder_num_attention_heads = config.encoder_attention_heads
|
# encoder-decoder models usually don't have text config
|
||||||
|
# below is needed only for Pix2Struct which we cannot modify now due to BC
|
||||||
|
config = config.get_text_config()
|
||||||
|
encoder_num_attention_heads = (
|
||||||
|
config.encoder_attention_heads
|
||||||
|
if hasattr(config, "encoder_attention_heads")
|
||||||
|
else config.num_attention_heads
|
||||||
|
)
|
||||||
encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
|
encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
|
||||||
batch_size, seq_length = inputs["decoder_input_ids"].shape
|
batch_size, seq_length = inputs["decoder_input_ids"].shape
|
||||||
for i in range(num_hidden_layers):
|
for i in range(num_hidden_layers):
|
||||||
@@ -1804,14 +1823,14 @@ class GenerationTesterMixin:
|
|||||||
def test_generate_continue_from_past_key_values(self):
|
def test_generate_continue_from_past_key_values(self):
|
||||||
# Tests that we can continue generating from past key values, returned from a previous `generate` call
|
# Tests that we can continue generating from past key values, returned from a previous `generate` call
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
|
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]):
|
||||||
self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
|
self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
|
||||||
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
|
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
|
||||||
self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
|
self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")
|
||||||
|
|
||||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config.get_text_config(), "use_cache"):
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||||
|
|
||||||
# Let's make it always:
|
# Let's make it always:
|
||||||
@@ -2251,7 +2270,7 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||||
# NOTE: assisted generation only works with cache on at the moment.
|
# NOTE: assisted generation only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config.get_text_config(), "use_cache"):
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
|
|||||||
@@ -82,14 +82,14 @@ class AriaVisionText2TextModelTester:
|
|||||||
moe_intermediate_size=4,
|
moe_intermediate_size=4,
|
||||||
moe_num_experts=4,
|
moe_num_experts=4,
|
||||||
moe_topk=2,
|
moe_topk=2,
|
||||||
num_attention_heads=20,
|
num_attention_heads=8,
|
||||||
num_experts_per_tok=3,
|
num_experts_per_tok=3,
|
||||||
num_hidden_layers=2,
|
num_hidden_layers=2,
|
||||||
num_key_value_heads=20,
|
num_key_value_heads=8,
|
||||||
rope_theta=5000000,
|
rope_theta=5000000,
|
||||||
vocab_size=99,
|
vocab_size=99,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
head_dim=2,
|
head_dim=4,
|
||||||
),
|
),
|
||||||
is_training=True,
|
is_training=True,
|
||||||
vision_config=Idefics3VisionConfig(
|
vision_config=Idefics3VisionConfig(
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from transformers import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
)
|
)
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
from transformers.models.mllama.configuration_mllama import MllamaTextConfig
|
from transformers.models.mllama.configuration_mllama import MllamaTextConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
cleanup,
|
cleanup,
|
||||||
@@ -378,6 +379,105 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
|||||||
def test_offloaded_cache_implementation(self, cache_implementation):
|
def test_offloaded_cache_implementation(self, cache_implementation):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
reason="Mllama cache type doesn't allow correct check on output `past_key_values` due to `Cache.crop()`"
|
||||||
|
)
|
||||||
|
def test_contrastive_generate_dict_outputs_use_cache(self, assistant_type):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Mllama can't do low memory due to `Cache.crop()`")
|
||||||
|
def test_contrastive_generate_low_memory(self, assistant_type):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="Mllama can't assisted decoding due to cache format and `Cache.crop()`")
|
||||||
|
def test_assisted_decoding_with_num_logits_to_keep(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
|
# overridden because mllama has special cache for self and cross attentions
|
||||||
|
def test_past_key_values_format(self):
|
||||||
|
# Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
|
||||||
|
# standard KV cache format is important for a consistent API (and for advanced generation methods).
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
model = model_class(config).to(torch_device)
|
||||||
|
if "use_cache" not in inputs:
|
||||||
|
inputs["use_cache"] = True
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
text_config = config.get_text_config()
|
||||||
|
num_hidden_layers = (
|
||||||
|
getattr(text_config, "decoder_layers", None)
|
||||||
|
or getattr(text_config, "num_decoder_layers", None)
|
||||||
|
or text_config.num_hidden_layers
|
||||||
|
)
|
||||||
|
num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
|
||||||
|
embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
|
||||||
|
per_head_embed_dim = embed_dim // num_attention_heads
|
||||||
|
|
||||||
|
# some models have different num-head for query vs key/value so we need to assign correct value
|
||||||
|
# BUT only after `per_head_embed_dim` is set
|
||||||
|
num_attention_heads = (
|
||||||
|
text_config.num_key_value_heads
|
||||||
|
if getattr(text_config, "num_key_value_heads", None) is not None
|
||||||
|
else num_attention_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
past_kv = outputs["past_key_values"]
|
||||||
|
self.assertEqual(len(past_kv), num_hidden_layers)
|
||||||
|
batch_size, seq_length = inputs["input_ids"].shape
|
||||||
|
for i in range(num_hidden_layers):
|
||||||
|
self.assertEqual(len(past_kv[0]), 2) # K V for the decoder = 2
|
||||||
|
if i in self.model_tester.text_config["cross_attention_layers"]:
|
||||||
|
self.assertEqual(
|
||||||
|
past_kv[i][0].shape,
|
||||||
|
(batch_size, num_attention_heads, self.model_tester.image_length, per_head_embed_dim),
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
past_kv[i][1].shape,
|
||||||
|
(batch_size, num_attention_heads, self.model_tester.image_length, per_head_embed_dim),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.assertEqual(
|
||||||
|
past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
# overridden because mllama has special cache for self and cross attentions
|
||||||
|
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
|
||||||
|
self.assertIsInstance(decoder_past_key_values, Cache)
|
||||||
|
self.assertListEqual(
|
||||||
|
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values],
|
||||||
|
[True] * len(decoder_past_key_values),
|
||||||
|
)
|
||||||
|
|
||||||
|
for layer_idx, layer_past_key_values in enumerate(decoder_past_key_values):
|
||||||
|
if layer_idx in self.model_tester.text_config["cross_attention_layers"]:
|
||||||
|
expected_shape = (
|
||||||
|
batch_size,
|
||||||
|
config.num_key_value_heads
|
||||||
|
if hasattr(config, "num_key_value_heads")
|
||||||
|
else config.num_attention_heads,
|
||||||
|
self.model_tester.image_length,
|
||||||
|
config.hidden_size // config.num_attention_heads,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# (batch, head, cache_length, head_features)
|
||||||
|
expected_shape = (
|
||||||
|
batch_size,
|
||||||
|
config.num_key_value_heads
|
||||||
|
if hasattr(config, "num_key_value_heads")
|
||||||
|
else config.num_attention_heads,
|
||||||
|
cache_length,
|
||||||
|
config.hidden_size // config.num_attention_heads,
|
||||||
|
)
|
||||||
|
# check shape key, value
|
||||||
|
self.assertListEqual([layer_past_key_values[0].shape], [expected_shape])
|
||||||
|
self.assertListEqual([layer_past_key_values[1].shape], [expected_shape])
|
||||||
|
|
||||||
def test_generate_text_only_with_cache(self):
|
def test_generate_text_only_with_cache(self):
|
||||||
"""
|
"""
|
||||||
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
|
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
|
||||||
|
|||||||
@@ -612,6 +612,18 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
def test_contrastive_generate_low_memory(self):
|
def test_contrastive_generate_low_memory(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"Moshi either needs deafult generation config or fix for fullgraph compile because it hardcodes SlidingWindowCache in custom generation loop."
|
||||||
|
)
|
||||||
|
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"Moshi either needs deafult generation config or fix for fullgraph compile because it hardcodes SlidingWindowCache in custom generation loop."
|
||||||
|
)
|
||||||
|
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@unittest.skip("Adapting this test is costly. `test_eager_matches_sdpa_generate` tests this already.")
|
@unittest.skip("Adapting this test is costly. `test_eager_matches_sdpa_generate` tests this already.")
|
||||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||||
@require_torch_sdpa
|
@require_torch_sdpa
|
||||||
|
|||||||
@@ -16,6 +16,8 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
PaliGemmaConfig,
|
PaliGemmaConfig,
|
||||||
PaliGemmaForConditionalGeneration,
|
PaliGemmaForConditionalGeneration,
|
||||||
@@ -348,3 +350,40 @@ class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
|
|||||||
@unittest.skip("Low memory will be removed soon so no need to fix it")
|
@unittest.skip("Low memory will be removed soon so no need to fix it")
|
||||||
def test_beam_search_low_memory(self):
|
def test_beam_search_low_memory(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@parameterized.expand([("random",), ("same",)])
|
||||||
|
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||||
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||||
|
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
|
||||||
|
def test_assisted_decoding_sample(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma2 has HybridCache which is not compatible with dola decoding")
|
||||||
|
def test_dola_decoding_sample(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma2 has HybridCache and doesn't support continue from past kv")
|
||||||
|
def test_generate_continue_from_past_key_values(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
|
||||||
|
def test_contrastive_generate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
|
||||||
|
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
|
||||||
|
def test_contrastive_generate_low_memory(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache")
|
||||||
|
def test_generate_with_static_cache(self):
|
||||||
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user