[VLMs] support passing embeds along with pixels (#38467)
* VLMs can work with embeds now * update more models * fix tests * fix copies * fixup * fix * style * unskip tests * fix copies * fix tests * style * omni modality models * qwen models had extra indentation * fix some other tests * fix copies * fix test last time * unrelated changes revert * we can't rely only on embeds * delete file * de-flake mistral3 * fix qwen models * fix style * fix tests * fix copies * deflake the test * modular reverted by fixes, fix again * flaky test, overwritten * fix copies * style
This commit is contained in:
committed by
GitHub
parent
20901f1d68
commit
f8b88866f5
@@ -457,14 +457,6 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
|
||||
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
||||
|
||||
@pytest.mark.generate
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
@unittest.skip(
|
||||
"KOSMOS-2 doesn't support inputs embeds. The test isn't skipped by checking input args because KOSMOS-2 has `generate()` overwritten"
|
||||
)
|
||||
def test_generate_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@require_torch_sdpa
|
||||
@unittest.skip("KOSMOS-2 doesn't support padding")
|
||||
@@ -613,6 +605,53 @@ class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||
self.clear_torch_jit_class_registry()
|
||||
|
||||
@pytest.mark.generate
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
def test_generate_from_inputs_embeds(self, _, num_beams):
|
||||
"""Tests that we can generate from `inputs_embeds` instead of `input_ids` in LLMs, VLMs, etc"""
|
||||
# NOTE: overwritten because Kosmos with ids prepares position ids differently from embeds
|
||||
# If the model get ids, all pad tokens are masked from position ids. That is not possible with embeds
|
||||
|
||||
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
|
||||
# if fails, you should probably update the `prepare_inputs_for_generation` function
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
config.is_decoder = True
|
||||
|
||||
# Skip models without explicit support
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
# Traditional way of generating text
|
||||
input_ids = inputs_dict.pop("input_ids")
|
||||
input_ids[input_ids == config.get_text_config().pad_token_id] = 0
|
||||
generation_kwargs = {
|
||||
"return_dict_in_generate": True,
|
||||
"output_scores": True,
|
||||
"num_beams": num_beams,
|
||||
"do_sample": False,
|
||||
"max_new_tokens": 5,
|
||||
"min_new_tokens": 5, # generate exactly 5 tokens
|
||||
"use_cache": True,
|
||||
}
|
||||
outputs_from_ids = model.generate(input_ids=input_ids, **generation_kwargs, **inputs_dict)
|
||||
self.assertEqual(outputs_from_ids.sequences.shape[:2], (input_ids.shape[0], input_ids.shape[1] + 5))
|
||||
|
||||
# Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output).
|
||||
# The output of the two calls should be the same.
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
outputs_from_embeds = model.generate(
|
||||
input_ids=input_ids, inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
self._check_similar_generate_outputs(outputs_from_ids, outputs_from_embeds)
|
||||
|
||||
# input_ids is not a required input on most models -- if we don't pass it, the newly generated tokens will
|
||||
# be the same
|
||||
outputs_from_embeds_wo_ids = model.generate(
|
||||
inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict
|
||||
)
|
||||
outputs_from_embeds.sequences = outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :]
|
||||
self._check_similar_generate_outputs(outputs_from_embeds_wo_ids, outputs_from_embeds)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
|
||||
Reference in New Issue
Block a user