IDEFICS: support inputs embeds (#34043)
* support embeds * use cache from config * style... * fix tests after rebase
This commit is contained in:
committed by
GitHub
parent
9d6998c759
commit
d087165db0
@@ -772,6 +772,12 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
|
||||
def test_custom_4d_attention_mask(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="IDEFICS has specific requirements for working with inputs embeds like passing also the ids and pixels"
|
||||
)
|
||||
def test_generate_from_inputs_embeds_decoder_only(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
|
||||
def test_generate_compile_fullgraph(self):
|
||||
pass
|
||||
|
||||
@@ -539,6 +539,31 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
def test_inputs_embeds_matches_input_ids_with_generate(self):
|
||||
# overwrite because IDEFICS needs ids and embeds at the input to be not None
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
# some models infer position ids/attn mask differently when input ids
|
||||
# by check if pad_token let's make sure no padding is in input ids
|
||||
not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
|
||||
input_ids[input_ids == pad_token_id] = not_pad_token_id
|
||||
del inputs["input_ids"]
|
||||
inputs_embeds = wte(input_ids)
|
||||
out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)
|
||||
out_embeds = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
|
||||
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
|
||||
@require_torch
|
||||
class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -526,6 +526,31 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
def test_inputs_embeds_matches_input_ids_with_generate(self):
|
||||
# overwrite because IDEFICS needs ids and embeds at the input to be not None
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
# some models infer position ids/attn mask differently when input ids
|
||||
# by check if pad_token let's make sure no padding is in input ids
|
||||
not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
|
||||
input_ids[input_ids == pad_token_id] = not_pad_token_id
|
||||
del inputs["input_ids"]
|
||||
inputs_embeds = wte(input_ids)
|
||||
out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)
|
||||
out_embeds = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
|
||||
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
|
||||
@require_torch
|
||||
class Idefics3ForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -428,6 +428,12 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
|
||||
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
||||
|
||||
@unittest.skip(
|
||||
"KOSMOS-2 doesn't support inputs embeds. The test isn't skipped by checking ipnut args because KOSMOS-2 has `generate()` overwritten"
|
||||
)
|
||||
def test_inputs_embeds_matches_input_ids_with_generate(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model_name = "microsoft/kosmos-2-patch14-224"
|
||||
|
||||
Reference in New Issue
Block a user