[VLMs] support passing embeds along with pixels (#38467)

* VLMs can work with embeds now

* update more models

* fix tests

* fix copies

* fixup

* fix

* style

* unskip tests

* fix copies

* fix tests

* style

* omni modality models

* qwen models had extra indentation

* fix some other tests

* fix copies

* fix test last time

* unrelated changes revert

* we can't rely only on embeds

* delete file

* de-flake mistral3

* fix qwen models

* fix style

* fix tests

* fix copies

* deflake the test

* modular changes were reverted by earlier fixes; fix them again

* flaky test, overwritten

* fix copies

* style
This commit is contained in:
Raushan Turganbay
2025-07-01 13:33:20 +02:00
committed by GitHub
parent 20901f1d68
commit f8b88866f5
78 changed files with 1131 additions and 1705 deletions

View File

@@ -63,9 +63,6 @@ class InternVLVisionText2TextModelTester:
image_seq_length=64,
vision_feature_layer=-1,
ignore_index=-100,
bos_token_id=0,
eos_token_id=0,
pad_token_id=0,
image_token_id=1,
num_channels=3,
image_size=64,
@@ -85,9 +82,9 @@ class InternVLVisionText2TextModelTester:
"rope_theta": 10000,
"mlp_ratio": 4,
"tie_word_embeddings": True,
"bos_token_id": 0,
"eos_token_id": 0,
"pad_token_id": 0,
"bos_token_id": 3,
"eos_token_id": 4,
"pad_token_id": 5,
},
vision_config={
"hidden_size": 32,
@@ -103,9 +100,9 @@ class InternVLVisionText2TextModelTester:
):
self.parent = parent
self.ignore_index = ignore_index
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.bos_token_id = text_config["bos_token_id"]
self.eos_token_id = text_config["eos_token_id"]
self.pad_token_id = text_config["pad_token_id"]
self.image_token_id = image_token_id
self.model_type = model_type
self.text_config = text_config
@@ -128,9 +125,6 @@ class InternVLVisionText2TextModelTester:
text_config=self.text_config,
vision_config=self.vision_config,
model_type=self.model_type,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
image_token_id=self.image_token_id,
image_seq_length=self.image_seq_length,
vision_feature_layer=self.vision_feature_layer,
@@ -148,7 +142,6 @@ class InternVLVisionText2TextModelTester:
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
# input_ids[:, -1] = self.pad_token_id
input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[:, : self.image_seq_length] = self.image_token_id
@@ -222,49 +215,6 @@ class InternVLModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
    """Run a forward pass using `inputs_embeds` instead of `input_ids`.

    For each model class, the prepared `input_ids` are turned into embeddings
    with the model's own input embedding layer. Both `input_ids` and
    `pixel_values` are removed from the inputs before the model is called.
    The test passes as long as the forward call does not raise.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)
        model.to(torch_device)
        model.eval()

        model_inputs = self._prepare_for_class(inputs_dict, model_class)
        # pop() hands back the ids and removes the key in a single step
        token_ids = model_inputs.pop("input_ids")
        # pixel values are dropped so that only the embeddings drive the forward pass
        model_inputs.pop("pixel_values")

        embedding_layer = model.get_input_embeddings()
        model_inputs["inputs_embeds"] = embedding_layer(token_ids)

        with torch.no_grad():
            model(**model_inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
    """Check that `inputs_embeds` and `input_ids` give the same first output.

    For each model class, the model is called twice with the same remaining
    inputs (`pixel_values` removed): once with `input_ids`, once with the
    embeddings of those ids taken from the model's input embedding layer.
    The first element of both outputs must be close.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)
        model.to(torch_device)
        model.eval()

        shared_inputs = self._prepare_for_class(inputs_dict, model_class)
        # the ids are passed explicitly below, so take them out of the shared kwargs
        token_ids = shared_inputs.pop("input_ids")
        shared_inputs.pop("pixel_values")

        token_embeds = model.get_input_embeddings()(token_ids)

        with torch.no_grad():
            from_ids = model(input_ids=token_ids, **shared_inputs)[0]
            from_embeds = model(inputs_embeds=token_embeds, **shared_inputs)[0]

        torch.testing.assert_close(from_embeds, from_ids)
@unittest.skip(reason="Compile not yet supported because in LLava models")
def test_sdpa_can_compile_dynamic(self):
pass