Idefics: enable generation tests (#34062)
* add idefics * conflicts after merging main * enable tests but need to fix some * fix tests * no print * fix/skip some slow tests * continue not skip * rebasing broke something; this is the fix
This commit is contained in:
committed by
GitHub
parent
dd4216b766
commit
23874f5948
@@ -153,7 +153,11 @@ class GenerationTesterMixin:
|
||||
# This is a band-aid for VLM models, to ensure they don't generate image/video tokens which would cause them
|
||||
# to crash. On pretrained models this isn't a risk, as they are trained to not generate these tokens.
|
||||
if config is not None:
|
||||
image_token_index = config.image_token_index if hasattr(config, "image_token_index") else None
|
||||
image_token_index = (
|
||||
config.image_token_index
|
||||
if getattr(config, "image_token_index", None) is not None
|
||||
else getattr(config, "image_token_id", None)
|
||||
)
|
||||
video_token_index = config.video_token_index if hasattr(config, "video_token_index") else None
|
||||
if image_token_index is not None and image_token_index < config.get_text_config().vocab_size:
|
||||
logits_processor_kwargs["bad_words_ids"].append([image_token_index])
|
||||
@@ -1496,13 +1500,14 @@ class GenerationTesterMixin:
|
||||
if "past_key_values" not in outputs:
|
||||
self.skipTest(reason="This model doesn't return `past_key_values`")
|
||||
|
||||
text_config = config.get_text_config()
|
||||
num_hidden_layers = (
|
||||
getattr(config, "decoder_layers", None)
|
||||
or getattr(config, "num_decoder_layers", None)
|
||||
or config.num_hidden_layers
|
||||
getattr(text_config, "decoder_layers", None)
|
||||
or getattr(text_config, "num_decoder_layers", None)
|
||||
or text_config.num_hidden_layers
|
||||
)
|
||||
num_attention_heads = getattr(config, "decoder_attention_heads", config.num_attention_heads)
|
||||
embed_dim = getattr(config, "d_model", config.hidden_size)
|
||||
num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
|
||||
embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
|
||||
per_head_embed_dim = embed_dim // num_attention_heads
|
||||
|
||||
past_kv = outputs["past_key_values"]
|
||||
|
||||
Reference in New Issue
Block a user