Idefics: enable generation tests (#34062)

* add idefics

* conflicts after merging main

* enable tests but need to fix some

* fix tests

* no print

* fix/skip some slow tests

* continue not skip

* rebasing broke something, this is the fix
This commit is contained in:
Raushan Turganbay
2024-10-15 11:17:14 +02:00
committed by GitHub
parent dd4216b766
commit 23874f5948
10 changed files with 406 additions and 96 deletions

View File

@@ -153,7 +153,11 @@ class GenerationTesterMixin:
# This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them
# to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
if config is not None:
image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None
image_token_index = (
config.image_token_index
if getattr(config, "image_token_index", None) is not None
else getattr(config, "image_token_id", None)
)
video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None
if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
logits_processor_kwargs["bad_words_ids"].append([image_token_index])
@@ -1496,13 +1500,14 @@ class GenerationTesterMixin:
if "past_key_values" not in outputs:
self.skipTest(reason="This model doesn't return `past_key_values`")
text_config = config.get_text_config()
num_hidden_layers = (
getattr(config, "decoder_layers", None)
or getattr(config, "num_decoder_layers", None)
or config.num_hidden_layers
getattr(text_config, "decoder_layers", None)
or getattr(text_config, "num_decoder_layers", None)
or text_config.num_hidden_layers
)
num_attention_heads = getattr(config, "decoder_attention_heads", config.num_attention_heads)
embed_dim = getattr(config, "d_model", config.hidden_size)
num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
per_head_embed_dim = embed_dim // num_attention_heads
past_kv = outputs["past_key_values"]