IDEFICS: support inputs embeds (#34043)
* support embeds * use cache from config * style... * fix tests after rebase
This commit is contained in:
committed by
GitHub
parent
9d6998c759
commit
d087165db0
@@ -3000,8 +3000,11 @@ class ModelTesterMixin:
|
||||
|
||||
def test_inputs_embeds_matches_input_ids_with_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class.__name__ not in get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES):
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class.__name__ not in [
|
||||
*get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
|
||||
*get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
|
||||
]:
|
||||
continue
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
@@ -3018,6 +3021,13 @@ class ModelTesterMixin:
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
|
||||
|
||||
# VLMs can't generate with embeds and pixels at the same time. We expect the user to pass merged
|
||||
# embeds already
|
||||
if model_class.__name__ in get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES):
|
||||
inputs.pop("pixel_values", None)
|
||||
inputs.pop("pixel_values_videos", None)
|
||||
inputs.pop("pixel_values_images", None)
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
if not self.is_encoder_decoder:
|
||||
input_ids = inputs["input_ids"]
|
||||
|
||||
Reference in New Issue
Block a user