Idefics: enable generation tests (#34062)

* add idefics

* conflicts after merging main

* enable tests but need to fix some

* fix tests

* no print

* fix/skip some slow tests

* continue not skip

* rebasing broken smth, this is the fix
This commit is contained in:
Raushan Turganbay
2024-10-15 11:17:14 +02:00
committed by GitHub
parent dd4216b766
commit 23874f5948
10 changed files with 406 additions and 96 deletions

View File

@@ -726,14 +726,23 @@ def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_en
elif mask_length_diff > 0: elif mask_length_diff > 0:
model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1) model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
# Handle cross attention models
if "cross_attention_mask" in model_kwargs: if "cross_attention_mask" in model_kwargs:
# Mllama case is special and has another mask for cross attention model # Mllama case
cross_mask = model_kwargs["cross_attention_mask"] cross_mask = model_kwargs["cross_attention_mask"]
if mask_length_diff < 0: if mask_length_diff < 0:
model_kwargs["cross_attention_mask"] = cross_mask[:, :mask_length_diff] model_kwargs["cross_attention_mask"] = cross_mask[:, :mask_length_diff]
elif mask_length_diff > 0: elif mask_length_diff > 0:
new_mask = cross_mask[:, -1:, :, :].repeat(1, mask_length_diff, 1, 1) new_mask = cross_mask[:, -1:, :, :].repeat(1, mask_length_diff, 1, 1)
model_kwargs["cross_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1) model_kwargs["cross_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
elif "image_attention_mask" in model_kwargs:
# IDEFICS case
cross_mask = model_kwargs["image_attention_mask"]
if mask_length_diff < 0:
model_kwargs["image_attention_mask"] = cross_mask[:, :mask_length_diff]
elif mask_length_diff > 0:
new_mask = cross_mask[:, -1:, :].repeat(1, mask_length_diff, 1)
model_kwargs["image_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
return model_kwargs return model_kwargs

View File

@@ -2005,6 +2005,7 @@ class GenerationMixin:
# generating the first new token or not, and we only want to use the embeddings for the first new token) # generating the first new token or not, and we only want to use the embeddings for the first new token)
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
model_kwargs["use_cache"] = True model_kwargs["use_cache"] = True
generation_config.use_cache = True
else: else:
model_kwargs["use_cache"] = generation_config.use_cache model_kwargs["use_cache"] = generation_config.use_cache
@@ -4299,7 +4300,8 @@ class GenerationMixin:
newly_added_length, newly_added_length,
is_decoder_attention=True, is_decoder_attention=True,
) )
else: # some (V)LLMs have hard requirement on SDPA and thus never return attn
elif outputs.attentions[0] is not None:
decoder_attentions = _split_model_outputs( decoder_attentions = _split_model_outputs(
decoder_attentions, decoder_attentions,
outputs.attentions, outputs.attentions,

View File

@@ -28,12 +28,12 @@ import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from ... import PreTrainedModel
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import ModelOutput from ...modeling_outputs import ModelOutput
from ...modeling_utils import PretrainedConfig from ...modeling_utils import PretrainedConfig, PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import ( from ...utils import (
add_start_docstrings, add_start_docstrings,
@@ -622,11 +622,9 @@ class IdeficsAttention(nn.Module):
query_states = self.q_layer_norm(query_states) query_states = self.q_layer_norm(query_states)
key_states = self.k_layer_norm(key_states) key_states = self.k_layer_norm(key_states)
causal_mask = attention_mask
if attention_mask is not None: if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577. # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -638,13 +636,13 @@ class IdeficsAttention(nn.Module):
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention( attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, query_states,
key_states, key_states,
value_states, value_states,
attn_mask=attention_mask, attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0, dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal, is_causal=is_causal,
) )
@@ -1490,7 +1488,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
return causal_mask return causal_mask
class IdeficsForVisionText2Text(IdeficsPreTrainedModel): class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"] _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
_tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"] _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
@@ -1670,6 +1668,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
position_ids=None, position_ids=None,
pixel_values=None, pixel_values=None,
image_hidden_states=None, image_hidden_states=None,
image_attention_mask=None,
use_cache=None, use_cache=None,
cache_position=None, cache_position=None,
**kwargs, **kwargs,
@@ -1678,6 +1677,8 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
if past_key_values is not None: if past_key_values is not None:
if input_ids.shape[1] != cache_position.shape[0]: if input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position] input_ids = input_ids[:, cache_position]
if image_attention_mask is not None:
image_attention_mask = image_attention_mask[:, -input_ids.shape[1] :]
if attention_mask is not None and position_ids is None: if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation # create position_ids on the fly for batch generation
@@ -1696,7 +1697,8 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
model_inputs["perceiver_embeddings"] = image_hidden_states model_inputs["perceiver_embeddings"] = image_hidden_states
else: else:
model_inputs["image_encoder_embeddings"] = image_hidden_states model_inputs["image_encoder_embeddings"] = image_hidden_states
pixel_values = None else:
model_inputs["pixel_values"] = pixel_values
model_inputs.update( model_inputs.update(
{ {
@@ -1706,21 +1708,13 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
"cache_position": cache_position, "cache_position": cache_position,
"position_ids": position_ids, "position_ids": position_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"pixel_values": pixel_values, "image_attention_mask": image_attention_mask,
"image_attention_mask": kwargs.get("image_attention_mask", None),
"interpolate_pos_encoding": kwargs.get("interpolate_pos_encoding", False), "interpolate_pos_encoding": kwargs.get("interpolate_pos_encoding", False),
} }
) )
return model_inputs return model_inputs
@staticmethod
def _expand_inputs_for_generation(
*args,
**model_kwargs,
):
return expand_inputs_for_generation(*args, **model_kwargs)
def _update_model_kwargs_for_generation( def _update_model_kwargs_for_generation(
self, self,
outputs: ModelOutput, outputs: ModelOutput,
@@ -1738,7 +1732,10 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
if "image_attention_mask" in model_kwargs: if "image_attention_mask" in model_kwargs:
image_attention_mask = model_kwargs["image_attention_mask"] image_attention_mask = model_kwargs["image_attention_mask"]
last_mask = image_attention_mask[:, -1, :].unsqueeze(1) last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
model_kwargs["image_attention_mask"] = last_mask if model_kwargs.get("use_cache", True):
model_kwargs["image_attention_mask"] = last_mask
else:
model_kwargs["image_attention_mask"] = torch.cat([image_attention_mask, last_mask], dim=1)
# Get the precomputed image_hidden_states # Get the precomputed image_hidden_states
model_kwargs["image_hidden_states"] = outputs.image_hidden_states model_kwargs["image_hidden_states"] = outputs.image_hidden_states

View File

@@ -1427,6 +1427,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
@@ -1657,35 +1658,19 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
past_key_values=None, past_key_values=None,
attention_mask=None, attention_mask=None,
inputs_embeds=None, inputs_embeds=None,
cache_position=None,
pixel_values=None,
pixel_attention_mask=None,
image_hidden_states=None,
num_logits_to_keep=None, num_logits_to_keep=None,
**kwargs, **kwargs,
): ):
past_length = 0 # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Omit tokens covered by past_key_values
if past_key_values is not None: if past_key_values is not None:
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore if inputs_embeds is not None: # Exception 1
past_length = past_key_values.get_seq_length() input_ids = input_ids[:, -cache_position.shape[0] :]
max_cache_length = past_key_values.get_max_cache_shape() elif input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and past_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
position_ids = kwargs.get("position_ids", None) position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None: if attention_mask is not None and position_ids is None:
@@ -1696,21 +1681,22 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_length == 0: # but IDEFICS requires noth ids and embeds to be present
model_inputs = {"inputs_embeds": inputs_embeds} if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
else: else:
model_inputs = {"input_ids": input_ids} # The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
if num_logits_to_keep is not None: if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep model_inputs["num_logits_to_keep"] = num_logits_to_keep
image_hidden_states = kwargs.get("image_hidden_states", None)
if image_hidden_states is not None: if image_hidden_states is not None:
pixel_values = None pixel_values = None
pixel_attention_mask = None pixel_attention_mask = None
else: else:
pixel_values = kwargs.get("pixel_values", None) pixel_values = pixel_values
pixel_attention_mask = kwargs.get("pixel_attention_mask", None) pixel_attention_mask = pixel_attention_mask
model_inputs.update( model_inputs.update(
{ {
"position_ids": position_ids, "position_ids": position_ids,

View File

@@ -24,7 +24,8 @@ from torch.nn import CrossEntropyLoss
from ... import PreTrainedModel from ... import PreTrainedModel
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...utils import ( from ...utils import (
@@ -953,6 +954,8 @@ class Idefics3Model(Idefics3PreTrainedModel):
past_seen_tokens = 0 past_seen_tokens = 0
if use_cache: if use_cache:
if past_key_values is None:
past_key_values = DynamicCache()
past_seen_tokens = past_key_values.get_seq_length() past_seen_tokens = past_key_values.get_seq_length()
if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
@@ -1019,6 +1022,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
attention_mask=attention_mask, attention_mask=attention_mask,
position_ids=position_ids, position_ids=position_ids,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
@@ -1040,7 +1044,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
"""The Idefics3 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. """, """The Idefics3 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. """,
IDEFICS3_START_DOCSTRING, IDEFICS3_START_DOCSTRING,
) )
class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel): class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3 # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3
@@ -1245,35 +1249,19 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel):
past_key_values=None, past_key_values=None,
attention_mask=None, attention_mask=None,
inputs_embeds=None, inputs_embeds=None,
cache_position=None,
pixel_values=None,
pixel_attention_mask=None,
image_hidden_states=None,
num_logits_to_keep=None, num_logits_to_keep=None,
**kwargs, **kwargs,
): ):
past_length = 0 # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Omit tokens covered by past_key_values
if past_key_values is not None: if past_key_values is not None:
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore if inputs_embeds is not None: # Exception 1
past_length = past_key_values.get_seq_length() input_ids = input_ids[:, -cache_position.shape[0] :]
max_cache_length = past_key_values.get_max_cache_shape() elif input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and past_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
position_ids = kwargs.get("position_ids", None) position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None: if attention_mask is not None and position_ids is None:
@@ -1284,21 +1272,22 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_length == 0: # but IDEFICS requires noth ids and embeds to be present
model_inputs = {"inputs_embeds": inputs_embeds} if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
else: else:
model_inputs = {"input_ids": input_ids} # The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
if num_logits_to_keep is not None: if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep model_inputs["num_logits_to_keep"] = num_logits_to_keep
image_hidden_states = kwargs.get("image_hidden_states", None)
if image_hidden_states is not None: if image_hidden_states is not None:
pixel_values = None pixel_values = None
pixel_attention_mask = None pixel_attention_mask = None
else: else:
pixel_values = kwargs.get("pixel_values", None) pixel_values = pixel_values
pixel_attention_mask = kwargs.get("pixel_attention_mask", None) pixel_attention_mask = pixel_attention_mask
model_inputs.update( model_inputs.update(
{ {
"position_ids": position_ids, "position_ids": position_ids,

View File

@@ -153,7 +153,11 @@ class GenerationTesterMixin:
# This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them # This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them
# to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens. # to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
if config is not None: if config is not None:
image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None image_token_index = (
config.image_token_index
if getattr(config, "image_token_index", None) is not None
else getattr(config, "image_token_id", None)
)
video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None
if image_token_index is not None and image_token_index < config.get_text_config().vocab_size: if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
logits_processor_kwargs["bad_words_ids"].append([image_token_index]) logits_processor_kwargs["bad_words_ids"].append([image_token_index])
@@ -1496,13 +1500,14 @@ class GenerationTesterMixin:
if "past_key_values" not in outputs: if "past_key_values" not in outputs:
self.skipTest(reason="This model doesn't return `past_key_values`") self.skipTest(reason="This model doesn't return `past_key_values`")
text_config = config.get_text_config()
num_hidden_layers = ( num_hidden_layers = (
getattr(config, "decoder_layers", None) getattr(text_config, "decoder_layers", None)
or getattr(config, "num_decoder_layers", None) or getattr(text_config, "num_decoder_layers", None)
or config.num_hidden_layers or text_config.num_hidden_layers
) )
num_attention_heads = getattr(config, "decoder_attention_heads", config.num_attention_heads) num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
embed_dim = getattr(config, "d_model", config.hidden_size) embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
per_head_embed_dim = embed_dim // num_attention_heads per_head_embed_dim = embed_dim // num_attention_heads
past_kv = outputs["past_key_values"] past_kv = outputs["past_key_values"]

View File

@@ -14,8 +14,10 @@
# limitations under the License. # limitations under the License.
"""Testing suite for the PyTorch Idefics model.""" """Testing suite for the PyTorch Idefics model."""
import inspect
import unittest import unittest
import pytest
from parameterized import parameterized from parameterized import parameterized
from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available
@@ -31,6 +33,7 @@ from transformers.testing_utils import (
) )
from transformers.utils import cached_property from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin from ...test_pipeline_mixin import PipelineTesterMixin
@@ -318,6 +321,12 @@ class IdeficsModelTester:
def test_eager_matches_sdpa_inference(self, torch_dtype: str): def test_eager_matches_sdpa_inference(self, torch_dtype: str):
self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test") self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
@require_torch_sdpa
@slow
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_generate(self):
self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test")
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required") @unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch @require_torch
@@ -580,8 +589,9 @@ class IdeficsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
@unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required") @unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required")
@require_torch @require_torch
class IdeficsForVisionText2TextTest(IdeficsModelTest, unittest.TestCase): class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (IdeficsForVisionText2Text,) if is_torch_available() else () all_model_classes = (IdeficsForVisionText2Text,) if is_torch_available() else ()
all_generative_model_classes = (IdeficsForVisionText2Text,) if is_torch_available() else ()
def setUp(self): def setUp(self):
self.model_tester = IdeficsModelTester( self.model_tester = IdeficsModelTester(
@@ -590,6 +600,182 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, unittest.TestCase):
) )
self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37) self.config_tester = ConfigTester(self, config_class=IdeficsConfig, hidden_size=37)
@pytest.mark.generate
def test_left_padding_compatibility(self):
"""Overwrite because IDEFICS needs image attention mask to be also padded"""
# NOTE: left-padding results in small numerical differences. This is expected.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
def _prepare_model_kwargs(input_ids, attention_mask, image_attention_mask, signature):
model_kwargs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"image_attention_mask": image_attention_mask,
}
if "position_ids" in signature:
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
model_kwargs["cache_position"] = cache_position
return model_kwargs
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
input_ids = inputs_dict.pop("input_ids")
attention_mask = inputs_dict.pop("attention_mask")
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
image_attention_mask = inputs_dict.pop("image_attention_mask", None)
model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()
# no cache as some models require special cache classes to be init outside forward
model.generation_config.use_cache = False
# Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, image_attention_mask, signature)
next_logits_wo_padding = model(**model_kwargs, **inputs_dict).logits[:, -1, :]
# With left-padding (length 32)
# can hardcode pad_token to be 0 as we'll do attn masking anyway
pad_token_id = (
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
)
pad_size = (input_ids.shape[0], 32)
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
pad_size_img = (input_ids.shape[0], 32, image_attention_mask.shape[-1])
extra_img_mask = torch.zeros(pad_size_img, dtype=image_attention_mask.dtype, device=torch_device)
padded_image_attention_mask = torch.cat([extra_img_mask, image_attention_mask], dim=1)
model_kwargs = _prepare_model_kwargs(
padded_input_ids, padded_attention_mask, padded_image_attention_mask, signature
)
next_logits_with_padding = model(**model_kwargs, **inputs_dict).logits[:, -1, :]
# They should result in very similar logits
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5))
@pytest.mark.generate
def test_generate_continue_from_past_key_values(self):
"""Overwrite because IDEFICS needs image attention mask to be also processed"""
# Tests that we can continue generating from past key values, returned from a previous `generate` call
for model_class in self.all_generative_model_classes:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
# Let's make it always:
# 1. use cache (for obvious reasons)
# 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
# would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
# continuation would force it to generate beyond an EOS token)
# 3. ignore `token_type_ids` for simplicity
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
# active by default on some models
# 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
# we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
# repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls
# with cache, what is considered a prompt is different in the two cases.
model = model_class(config).to(torch_device)
model.eval()
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
model.generation_config.forced_eos_token_id = None
model.generation_config.encoder_no_repeat_ngram_size = 0
model.generation_config.use_cache = True
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=4, return_dict_in_generate=True)
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
# inputs may need to be tweaked across `generate` calls (like the attention mask).
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True)
# Continue from the tokens generated above, preparing the inputs accordingly
inputs["past_key_values"] = outputs_cached.past_key_values
new_attention_len = outputs_cached.sequences.shape[-1]
inputs["input_ids"] = outputs_cached.sequences
if "attention_mask" in inputs:
inputs["attention_mask"] = torch.nn.functional.pad(
inputs["attention_mask"],
(0, new_attention_len - inputs["attention_mask"].shape[1]),
mode="constant",
value=1,
)
if "image_attention_mask" in inputs:
inputs["image_attention_mask"] = inputs["image_attention_mask"][:, -1:, :]
outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=1, return_dict_in_generate=True)
# The two sets of generated text and past kv should be equal to each other
self.assertListEqual(outputs.sequences.tolist(), outputs_cached.sequences.tolist())
for layer_idx in range(len(outputs_cached.past_key_values)):
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
self.assertTrue(
torch.allclose(
outputs.past_key_values[layer_idx][kv_idx],
outputs_cached.past_key_values[layer_idx][kv_idx],
)
)
@pytest.mark.generate
def test_generate_without_input_ids(self):
"""Overwrite because IDEFICS needs image attention mask to be also processed and requires image at input always."""
config, input_dict = self.prepare_config_and_inputs_for_generate()
pixel_values = input_dict["pixel_values"]
image_attention_mask = input_dict["image_attention_mask"][:, -1:, :]
# hack in case they are equal, otherwise the attn mask will be [0]
if config.bos_token_id == config.pad_token_id:
config.pad_token_id = None
for model_class in self.all_generative_model_classes:
model = model_class(config).to(torch_device)
model.eval()
output_ids_generate = model.generate(
pixel_values=pixel_values,
image_attention_mask=image_attention_mask,
do_sample=False,
max_new_tokens=self.max_new_tokens,
remove_invalid_values=True,
)
self.assertIsNotNone(output_ids_generate)
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
"""
Overwrite from generation tests because Idefics has only SDPA layers.
Do not skip because we still want generation tests to run. Rather we can remove checks for shape.
"""
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip(reason="We only test the model that takes in multiple images")
def test_custom_4d_attention_mask(self):
pass
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
def test_generate_compile_fullgraph(self):
pass
@unittest.skip(reason="We only test the model that takes in multiple images") @unittest.skip(reason="We only test the model that takes in multiple images")
def test_model(self): def test_model(self):
pass pass

View File

@@ -19,6 +19,7 @@ import gc
import unittest import unittest
from io import BytesIO from io import BytesIO
import pytest
import requests import requests
from transformers import ( from transformers import (
@@ -96,7 +97,7 @@ class Idefics2VisionText2TextModelTester:
"pad_token_id": 0, # None in the original configuration_mistral, we set it to the unk_token_id "pad_token_id": 0, # None in the original configuration_mistral, we set it to the unk_token_id
"bos_token_id": 1, "bos_token_id": 1,
"eos_token_id": 2, "eos_token_id": 2,
"image_token_id": 32_001, "image_token_id": 99,
"tie_word_embeddings": False, "tie_word_embeddings": False,
"rope_theta": 10000.0, "rope_theta": 10000.0,
"sliding_window": 32, "sliding_window": 32,
@@ -334,6 +335,7 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
""" """
all_model_classes = (Idefics2ForConditionalGeneration,) if is_torch_available() else () all_model_classes = (Idefics2ForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Idefics2ForConditionalGeneration,) if is_torch_available() else ()
fx_compatible = False fx_compatible = False
test_pruning = False test_pruning = False
test_resize_embeddings = True test_resize_embeddings = True
@@ -356,6 +358,72 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
def test_flash_attn_2_inference_padding_right(self): def test_flash_attn_2_inference_padding_right(self):
pass pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip(
reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
)
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
@unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@pytest.mark.generate
def test_generate_from_inputs_embeds_decoder_only(self):
# overwrite because IDEFICS needs ids and embeds at the input to be not None
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
# Ignore:
# a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
# which would cause a mismatch),
config.pad_token_id = config.eos_token_id = -1
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
input_ids = inputs_dict.pop("input_ids")
# Traditional way of generating text
outputs_from_ids = model.generate(
input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
)
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_from_embeds = model.generate(
input_ids,
inputs_embeds=inputs_embeds,
max_new_tokens=5,
return_dict_in_generate=True,
output_scores=True,
)
self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
# But if we pass different inputs_embeds, we should get different outputs (the output text may be the
# same, but the logits will almost surely be different)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(
input_ids,
inputs_embeds=random_embeds,
max_new_tokens=5,
return_dict_in_generate=True,
output_scores=True,
)
for i in range(len(outputs_from_rand_embeds.scores)):
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
# We need to override as we need to prepare such that the image token is the last token # We need to override as we need to prepare such that the image token is the last token
def test_resize_tokens_embeddings(self): def test_resize_tokens_embeddings(self):
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common() (original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -19,6 +19,7 @@ import gc
import unittest import unittest
from io import BytesIO from io import BytesIO
import pytest
import requests import requests
from transformers import ( from transformers import (
@@ -321,6 +322,7 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
""" """
all_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else () all_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else ()
fx_compatible = False fx_compatible = False
test_pruning = False test_pruning = False
test_resize_embeddings = True test_resize_embeddings = True
@@ -343,6 +345,72 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
def test_flash_attn_2_inference_padding_right(self): def test_flash_attn_2_inference_padding_right(self):
pass pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip(
reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
)
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
@unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@pytest.mark.generate
def test_generate_from_inputs_embeds_decoder_only(self):
# overwrite because IDEFICS needs ids and embeds at the input to be not None
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
# Ignore:
# a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
# which would cause a mismatch),
config.pad_token_id = config.eos_token_id = -1
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
input_ids = inputs_dict.pop("input_ids")
# Traditional way of generating text
outputs_from_ids = model.generate(
input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
)
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_from_embeds = model.generate(
input_ids,
inputs_embeds=inputs_embeds,
max_new_tokens=5,
return_dict_in_generate=True,
output_scores=True,
)
self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
# But if we pass different inputs_embeds, we should get different outputs (the output text may be the
# same, but the logits will almost surely be different)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(
input_ids,
inputs_embeds=random_embeds,
max_new_tokens=5,
return_dict_in_generate=True,
output_scores=True,
)
for i in range(len(outputs_from_rand_embeds.scores)):
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
# We need to override as we need to prepare such that the image token is the last token # We need to override as we need to prepare such that the image token is the last token
def test_resize_tokens_embeddings(self): def test_resize_tokens_embeddings(self):
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common() (original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -4768,7 +4768,7 @@ class ModelTesterMixin:
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# TODO: to change it in the future with other relevant auto classes # TODO: to change it in the future with other relevant auto classes
fa2_model = AutoModelForCausalLM.from_config( fa2_model = model_class._from_config(
config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16 config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
).to(torch_device) ).to(torch_device)
@@ -4789,7 +4789,7 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
fa2_model.save_pretrained(tmpdirname) fa2_model.save_pretrained(tmpdirname)
model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname) model_from_pretrained = model_class.from_pretrained(tmpdirname)
self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2") self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")