IDEFICS: support inputs embeds (#34043)

* support embeds

* use cache from config

* style...

* fix tests after rebase
This commit is contained in:
Raushan Turganbay
2024-10-16 09:25:26 +02:00
committed by GitHub
parent 9d6998c759
commit d087165db0
7 changed files with 100 additions and 18 deletions

View File

@@ -526,6 +526,31 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
def test_inputs_embeds_matches_input_ids_with_generate(self):
# overwrite because IDEFICS needs ids and embeds at the input to be not None
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
wte = model.get_input_embeddings()
input_ids = inputs["input_ids"]
# some models infer position ids/attn mask differently when input ids
# by check if pad_token let's make sure no padding is in input ids
not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
input_ids[input_ids == pad_token_id] = not_pad_token_id
del inputs["input_ids"]
inputs_embeds = wte(input_ids)
out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)
out_embeds = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
self.assertTrue(torch.allclose(out_embeds, out_ids))
@require_torch
class Idefics3ForConditionalGenerationIntegrationTest(unittest.TestCase):