🔴 [VLM] Add base model without head (#37033)
* i guess reverted all CdGen classes * style * llava onevision * fix copies * fix some tests * some more tests * dump * skip these * nevermind, i am dumb * revert fix not needed * fixup * fixup * another fixup * more fixup to make ci finally happy * fixup after rebasing * fix qwen tests * add internVL + typos here and there * image token index -> id * style * fix init weights * revert blip-2 not supported * address comments * fix copies * revert blip2 test file as well * as discussed internally, revert back CdGen models * fix some tests * fix more tests for compile * CI red * fix copies * enumerate explicitly allowed models * address comments * fix tests * fixup * style again * add tests for new model class * another fixup ( x _ x ) * [fixup] unused attributes can be removed post-deprecation
This commit is contained in:
committed by
GitHub
parent
3fa8d9c20e
commit
17742bd9c8
@@ -25,6 +25,7 @@ from transformers import (
|
||||
MllamaConfig,
|
||||
MllamaForCausalLM,
|
||||
MllamaForConditionalGeneration,
|
||||
MllamaModel,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
@@ -262,7 +263,14 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
Model tester for `MllamaForConditionalGeneration`.
|
||||
"""
|
||||
|
||||
all_model_classes = (MllamaForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(
|
||||
MllamaModel,
|
||||
MllamaForConditionalGeneration,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = {"image-text-to-text": MllamaForConditionalGeneration} if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
@@ -325,19 +333,18 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
# resizing embeddings should result in successful loss computation
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model_vocab_size = config.get_text_config().vocab_size
|
||||
inputs = self._prepare_for_class(inputs, model_class, return_labels=True)
|
||||
# Resize embeddings and call forward
|
||||
model.resize_token_embeddings(model_vocab_size + 10)
|
||||
output = model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
labels=inputs["labels"],
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue("loss" in output)
|
||||
model = MllamaForConditionalGeneration(config).to(torch_device)
|
||||
model_vocab_size = config.get_text_config().vocab_size
|
||||
inputs = self._prepare_for_class(inputs, MllamaForConditionalGeneration, return_labels=True)
|
||||
# Resize embeddings and call forward
|
||||
model.resize_token_embeddings(model_vocab_size + 10)
|
||||
output = model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
labels=inputs["labels"],
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertTrue("loss" in output)
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||
@@ -409,6 +416,18 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
def test_assisted_decoding_with_num_logits_to_keep(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Mllama uses self.weights dirrectly causing device mismatch when offloading`")
|
||||
def test_cpu_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Mllama uses self.weights dirrectly causing device mismatch when offloading`")
|
||||
def test_disk_offload_bin(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Mllama uses self.weights dirrectly causing device mismatch when offloading`")
|
||||
def test_disk_offload_safetensors(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
# overridden because mllama is not an encoder-decoder model, but has encoder-decoder-like cache
|
||||
def test_past_key_values_format(self):
|
||||
@@ -501,7 +520,7 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
||||
"""
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
Reference in New Issue
Block a user