[VLMs] support passing embeds along with pixels (#38467)
* VLMs can work with embeds now * update more models * fix tests * fix copies * fixup * fix * style * unskip tests * fix copies * fix tests * style * omni modality models * qwen models had extra indentation * fix some other tests * fix copies * fix test last time * unrelated changes revert * we can't rely only on embeds * delete file * de-flake mistral3 * fix qwen models * fix style * fix tests * fix copies * deflake the test * modular reverted by fixes, fix again * flaky test, overwritten * fix copies * style
This commit is contained in:
committed by
GitHub
parent
20901f1d68
commit
f8b88866f5
@@ -51,9 +51,6 @@ class GotOcr2VisionText2TextModelTester:
|
||||
num_channels=3,
|
||||
ignore_index=-100,
|
||||
image_size=64,
|
||||
bos_token_id=0,
|
||||
eos_token_id=0,
|
||||
pad_token_id=0,
|
||||
image_token_index=1,
|
||||
model_type="got_ocr2",
|
||||
is_training=True,
|
||||
@@ -71,6 +68,9 @@ class GotOcr2VisionText2TextModelTester:
|
||||
"rope_theta": 10000,
|
||||
"mlp_ratio": 4,
|
||||
"tie_word_embeddings": True,
|
||||
"bos_token_id": 2,
|
||||
"eos_token_id": 3,
|
||||
"pad_token_id": 4,
|
||||
},
|
||||
vision_config={
|
||||
"num_hidden_layers": 2,
|
||||
@@ -85,9 +85,9 @@ class GotOcr2VisionText2TextModelTester:
|
||||
):
|
||||
self.parent = parent
|
||||
self.ignore_index = ignore_index
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = text_config["bos_token_id"]
|
||||
self.eos_token_id = text_config["eos_token_id"]
|
||||
self.pad_token_id = text_config["pad_token_id"]
|
||||
self.image_token_index = image_token_index
|
||||
self.model_type = model_type
|
||||
self.text_config = text_config
|
||||
@@ -109,9 +109,6 @@ class GotOcr2VisionText2TextModelTester:
|
||||
text_config=self.text_config,
|
||||
vision_config=self.vision_config,
|
||||
model_type=self.model_type,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
image_token_index=self.image_token_index,
|
||||
)
|
||||
|
||||
@@ -127,7 +124,6 @@ class GotOcr2VisionText2TextModelTester:
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||
|
||||
# input_ids[:, -1] = self.pad_token_id
|
||||
input_ids[input_ids == self.image_token_index] = self.pad_token_id
|
||||
input_ids[:, : self.num_image_tokens] = self.image_token_index
|
||||
|
||||
@@ -181,55 +177,6 @@ class GotOcr2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
@unittest.skip(
|
||||
reason="VLMs can't generate from inputs embeds and pixels. This can be tested as part of bacbone LM, no need to run the test for VLMs"
|
||||
)
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user