[VLMs] support passing embeds along with pixels (#38467)
* VLMs can work with embeds now * update more models * fix tests * fix copies * fixup * fix * style * unskip tests * fix copies * fix tests * style * omni modality models * qwen models had extra indentation * fix some other tests * fix copies * fix test last time * unrelated changes revert * we can't rely only on embeds * delete file * de-flake mistral3 * fix qwen models * fix style * fix tests * fix copies * deflake the test * modular reverted by fixes, fix again * flaky test, overwritten * fix copies * style
This commit is contained in:
committed by
GitHub
parent
20901f1d68
commit
f8b88866f5
@@ -189,49 +189,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
|
||||
self.model_tester = AriaVisionText2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=AriaConfig, has_text_modality=False)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
@@ -270,14 +227,6 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Unsupported")
|
||||
def test_generate_from_inputs_embeds_0_greedy(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Unsupported")
|
||||
def test_generate_from_inputs_embeds_1_beam_search(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Dynamic control flow due to MoE")
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user