Idefics: enable generation tests (#34062)

* add idefics

* conflicts after merging main

* enable tests but need to fix some

* fix tests

* no print

* fix/skip some slow tests

* continue not skip

* rebasing broken smth, this is the fix
This commit is contained in:
Raushan Turganbay
2024-10-15 11:17:14 +02:00
committed by GitHub
parent dd4216b766
commit 23874f5948
10 changed files with 406 additions and 96 deletions

View File

@@ -19,6 +19,7 @@ import gc
import unittest
from io import BytesIO
import pytest
import requests
from transformers import (
@@ -321,6 +322,7 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
"""
all_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Idefics3ForConditionalGeneration,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = True
@@ -343,6 +345,72 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
def test_flash_attn_2_inference_padding_right(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="Contrastive search is not implemented for VLMs that do cross-attn")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip(
reason="Prompt lookup decoding needs a way to indicate `bad_word_ids` that should not be suggested as candidates"
)
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
@unittest.skip(reason=" FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@pytest.mark.generate
def test_generate_from_inputs_embeds_decoder_only(self):
# overwrite because IDEFICS needs ids and embeds at the input to be not None
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
# Ignore:
# a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
# which would cause a mismatch),
config.pad_token_id = config.eos_token_id = -1
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
input_ids = inputs_dict.pop("input_ids")
# Traditional way of generating text
outputs_from_ids = model.generate(
input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True
)
self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5))
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output)
inputs_embeds = model.get_input_embeddings()(input_ids)
outputs_from_embeds = model.generate(
input_ids,
inputs_embeds=inputs_embeds,
max_new_tokens=5,
return_dict_in_generate=True,
output_scores=True,
)
self.assertListEqual(outputs_from_ids.sequences.tolist(), outputs_from_embeds.sequences.tolist())
# But if we pass different inputs_embeds, we should get different outputs (the output text may be the
# same, but the logits will almost surely be different)
random_embeds = torch.rand_like(inputs_embeds)
outputs_from_rand_embeds = model.generate(
input_ids,
inputs_embeds=random_embeds,
max_new_tokens=5,
return_dict_in_generate=True,
output_scores=True,
)
for i in range(len(outputs_from_rand_embeds.scores)):
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
# We need to override as we need to prepare such that the image token is the last token
def test_resize_tokens_embeddings(self):
(original_config, inputs_dict) = self.model_tester.prepare_config_and_inputs_for_common()