IDEFICS: support inputs embeds (#34043)

* support embeds

* use cache from config

* style...

* fix tests after rebase
This commit is contained in:
Raushan Turganbay
2024-10-16 09:25:26 +02:00
committed by GitHub
parent 9d6998c759
commit d087165db0
7 changed files with 100 additions and 18 deletions

View File

@@ -3000,8 +3000,11 @@ class ModelTesterMixin:
def test_inputs_embeds_matches_input_ids_with_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_generative_model_classes:
if model_class.__name__ not in get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES):
for model_class in self.all_model_classes:
if model_class.__name__ not in [
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
*get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
]:
continue
model = model_class(config)
model.to(torch_device)
@@ -3018,6 +3021,13 @@ class ModelTesterMixin:
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
# VLMs can't generate with embeds and pixels at the same time. We expect the user to pass merged
# embeds already
if model_class.__name__ in get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES):
inputs.pop("pixel_values", None)
inputs.pop("pixel_values_videos", None)
inputs.pop("pixel_values_images", None)
wte = model.get_input_embeddings()
if not self.is_encoder_decoder:
input_ids = inputs["input_ids"]