[VLMs] support passing embeds along with pixels (#38467)
* VLMs can work with embeds now * update more models * fix tests * fix copies * fixup * fix * style * unskip tests * fix copies * fix tests * style * omni modality models * qwen models had extra indentation * fix some other tests * fix copies * fix test last time * unrelated changes revert * we can't rely only on embeds * delete file * de-flake mistral3 * fix qwen models * fix style * fix tests * fix copies * deflake the test * modular reverted by fixes, fix again * flaky test, overwritten * fix copies * style
This commit is contained in:
committed by
GitHub
parent
20901f1d68
commit
f8b88866f5
@@ -118,27 +118,6 @@ from unittest.mock import patch
|
||||
from transformers.utils import is_sklearn_available
|
||||
|
||||
|
||||
# TODO: raushan remove this when VLMs start accepting input embeds
|
||||
VLM_CLASS_NAMES = [
|
||||
"llava",
|
||||
"idefics2",
|
||||
"idefics3",
|
||||
"mllama",
|
||||
"paligemma",
|
||||
"emu3",
|
||||
"gotocr2",
|
||||
"qwen2vl",
|
||||
"qwen2_5_vl",
|
||||
"ayavision",
|
||||
"janus",
|
||||
"gemma3",
|
||||
"mistral3",
|
||||
"chameleon",
|
||||
"internvl",
|
||||
"qwen2_5omni", # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni`
|
||||
]
|
||||
|
||||
|
||||
class GenerationTesterMixin:
|
||||
input_name = "input_ids"
|
||||
model_tester = None
|
||||
@@ -1228,7 +1207,23 @@ class GenerationTesterMixin:
|
||||
"blip2", # overridden `generate()`
|
||||
"instructblip",
|
||||
"instructblipvideo",
|
||||
*VLM_CLASS_NAMES, # shouldn't suggest image tokens
|
||||
# All models below: shouldn't suggest image tokens. Can be fixed by passing `suppress_ids` to candidate generator: @joaa @raushan
|
||||
"llava",
|
||||
"idefics2",
|
||||
"idefics3",
|
||||
"mllama",
|
||||
"paligemma",
|
||||
"emu3",
|
||||
"gotocr2",
|
||||
"qwen2vl",
|
||||
"qwen2_5_vl",
|
||||
"ayavision",
|
||||
"janus",
|
||||
"gemma3",
|
||||
"mistral3",
|
||||
"chameleon",
|
||||
"internvl",
|
||||
"qwen2_5omni", # the file is named `qwen2_5_omni`, but the model class is `Qwen2_5Omni`,
|
||||
]
|
||||
):
|
||||
self.skipTest(reason="May fix in the future: need model-specific fixes")
|
||||
@@ -1641,6 +1636,58 @@ class GenerationTesterMixin:
|
||||
self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
|
||||
self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_from_random_inputs_embeds(self):
|
||||
"""
|
||||
Text-only: Tests that different `inputs_embeds` generate different outputs in models with `main_input=="input_ids"`.
|
||||
Some models have 'images' as main input and thus can't generate with random text embeddings.
|
||||
See `test_generate_from_inputs_embeds` for more general checks.
|
||||
"""
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
continue
|
||||
config.is_decoder = True
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
|
||||
continue
|
||||
|
||||
# No easy fix, let's skip the test for now
|
||||
has_complex_embeds_computation = any(
|
||||
model_name in model_class.__name__.lower() for model_name in ["moshi"]
|
||||
)
|
||||
|
||||
if model_class.main_input_name != "input_ids" or has_complex_embeds_computation:
|
||||
self.skipTest(
|
||||
"The model's main input name in not `input_ids` and we need kwargs from input dict as well."
|
||||
)
|
||||
|
||||
if hasattr(config, "scale_embedding"):
|
||||
config.scale_embedding = False
|
||||
|
||||
generation_kwargs = {
|
||||
"return_dict_in_generate": True,
|
||||
"output_scores": True,
|
||||
"do_sample": False,
|
||||
"max_new_tokens": 5,
|
||||
"min_new_tokens": 5, # generate exactly 5 tokens
|
||||
}
|
||||
|
||||
input_ids = inputs_dict.pop("input_ids")
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds, **generation_kwargs)
|
||||
|
||||
# If we pass different inputs_embeds, we should get different outputs (the output text may be the
|
||||
# same, but the logits will almost surely be different)
|
||||
random_embeds = torch.rand_like(inputs_embeds)
|
||||
outputs_from_rand_embeds = model.generate(
|
||||
input_ids=input_ids, inputs_embeds=random_embeds, **generation_kwargs
|
||||
)
|
||||
for i in range(len(outputs_from_rand_embeds.scores)):
|
||||
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
|
||||
|
||||
@pytest.mark.generate
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
def test_generate_from_inputs_embeds(self, _, num_beams):
|
||||
@@ -1662,34 +1709,22 @@ class GenerationTesterMixin:
|
||||
continue
|
||||
|
||||
# There are a few exception patterns in this test:
|
||||
# 1 - Some models can't generate without `input_ids`, when `inputs_embeds` are passed
|
||||
requires_inputs_ids = any(model_name in model_class.__name__.lower() for model_name in ["idefics"])
|
||||
# 2 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex
|
||||
# 1 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex
|
||||
# than calling the embedding layer with `input_ids`. Subcases of this exception:
|
||||
# 2.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag)
|
||||
# 1.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag)
|
||||
if hasattr(config, "scale_embedding"):
|
||||
config.scale_embedding = False
|
||||
# 2.B - Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the
|
||||
# exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the
|
||||
# checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
|
||||
pixel_values_is_mutually_exclusive = any(
|
||||
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
|
||||
)
|
||||
if pixel_values_is_mutually_exclusive:
|
||||
inputs_dict.pop("pixel_values", None)
|
||||
inputs_dict.pop("pixel_values_videos", None)
|
||||
inputs_dict.pop("pixel_values_images", None)
|
||||
# HACK - in the case of granite speech, input_features and inputs_embeds are mutually exclusive;
|
||||
# this is similar to VLMs and should likely be standardized for similar audio models in the future,
|
||||
# then made generic here.
|
||||
if "granitespeech" in model_class.__name__.lower():
|
||||
inputs_dict.pop("input_features", None)
|
||||
|
||||
# 2.C - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds`
|
||||
# 1.B - No easy fix, let's skip the check that compares the outputs from `input_ids` and `inputs_embeds`
|
||||
has_complex_embeds_computation = any(
|
||||
model_name in model_class.__name__.lower() for model_name in ["moshi"]
|
||||
)
|
||||
# 3 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate,
|
||||
# 2 - `inputs_dict` doesn't contain `attention_mask`. When `attention_mask` is not passed to generate,
|
||||
# we infer it from `input_ids`. The last test case will fail if there is a pad token in the original input.
|
||||
missing_attention_mask = "attention_mask" not in inputs_dict
|
||||
|
||||
@@ -1702,31 +1737,23 @@ class GenerationTesterMixin:
|
||||
"do_sample": False,
|
||||
"max_new_tokens": 5,
|
||||
"min_new_tokens": 5, # generate exactly 5 tokens
|
||||
"use_cache": True,
|
||||
}
|
||||
outputs_from_ids = model.generate(input_ids, **generation_kwargs, **inputs_dict)
|
||||
outputs_from_ids = model.generate(input_ids=input_ids, **generation_kwargs, **inputs_dict)
|
||||
self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5))
|
||||
|
||||
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
|
||||
# The output of the two calls should be the same.
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
outputs_from_embeds = model.generate(
|
||||
input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||
input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
if not has_complex_embeds_computation:
|
||||
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
|
||||
|
||||
# If we pass different inputs_embeds, we should get different outputs (the output text may be the
|
||||
# same, but the logits will almost surely be different)
|
||||
random_embeds = torch.rand_like(inputs_embeds)
|
||||
outputs_from_rand_embeds = model.generate(
|
||||
input_ids, inputs_embeds=random_embeds, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
for i in range(len(outputs_from_rand_embeds.scores)):
|
||||
self.assertFalse(torch.allclose(outputs_from_embeds.scores[i], outputs_from_rand_embeds.scores[i]))
|
||||
|
||||
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
|
||||
# be the same
|
||||
if not (requires_inputs_ids or missing_attention_mask):
|
||||
if not missing_attention_mask:
|
||||
outputs_from_embeds_wo_ids = model.generate(
|
||||
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
@@ -1753,17 +1780,6 @@ class GenerationTesterMixin:
|
||||
if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
|
||||
self.skipTest(reason="This model does not support `inputs_embeds` in generation")
|
||||
|
||||
# Some VLMs assume `inputs_embeds` and `pixel_values` are mutually exclusive AND fall in the
|
||||
# exception above (complex `inputs_embeds` computation). Popping `pixel_values` allow us to run the
|
||||
# checks without adding test complexity. Ditto for `pixel_values_videos` and `pixel_values_images`
|
||||
pixel_values_is_mutually_exclusive = any(
|
||||
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
|
||||
)
|
||||
if pixel_values_is_mutually_exclusive:
|
||||
inputs_dict.pop("pixel_values", None)
|
||||
inputs_dict.pop("pixel_values_videos", None)
|
||||
inputs_dict.pop("pixel_values_images", None)
|
||||
|
||||
input_ids = inputs_dict.pop("input_ids")
|
||||
|
||||
model.config.use_cache = True
|
||||
@@ -1925,14 +1941,6 @@ class GenerationTesterMixin:
|
||||
if "past_key_values" not in outputs:
|
||||
self.skipTest(reason="This model doesn't return `past_key_values`")
|
||||
|
||||
pixel_values_is_mutually_exclusive = any(
|
||||
model_name in model_class.__name__.lower() for model_name in VLM_CLASS_NAMES
|
||||
)
|
||||
if pixel_values_is_mutually_exclusive:
|
||||
inputs_dict.pop("pixel_values", None)
|
||||
inputs_dict.pop("pixel_values_videos", None)
|
||||
inputs_dict.pop("pixel_values_images", None)
|
||||
|
||||
input_ids = inputs_dict.pop("input_ids")
|
||||
|
||||
model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
|
||||
|
||||
Reference in New Issue
Block a user