Blip: get/set input embeddings correctly (#34152)
* set-get embeds * add tests * fix tests * remove * return dict True * fix tests * why did i remove this * enabel torchscript tests
This commit is contained in:
committed by
GitHub
parent
b53e44e847
commit
6beb3f1691
@@ -795,6 +795,12 @@ class BlipModel(BlipPreTrainedModel):
|
|||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.text_model.get_input_embeddings()
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.text_model.set_input_embeddings(value)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING)
|
||||||
def get_text_features(
|
def get_text_features(
|
||||||
self,
|
self,
|
||||||
@@ -1053,8 +1059,11 @@ class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
|
|||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
def get_input_embeddings(self) -> nn.Module:
|
def get_input_embeddings(self):
|
||||||
return self.vision_model.embeddings.patch_embedding
|
return self.text_decoder.get_input_embeddings()
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.text_decoder.set_input_embeddings(value)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=BlipForConditionalGenerationModelOutput, config_class=BlipVisionConfig)
|
@replace_return_docstrings(output_type=BlipForConditionalGenerationModelOutput, config_class=BlipVisionConfig)
|
||||||
@@ -1117,7 +1126,8 @@ class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:]
|
outputs = (outputs[0], outputs[1]) if labels is not None else (outputs[0],)
|
||||||
|
outputs += (image_embeds, vision_outputs[0]) + vision_outputs[2:]
|
||||||
return tuple(output for output in outputs if output is not None)
|
return tuple(output for output in outputs if output is not None)
|
||||||
|
|
||||||
return BlipForConditionalGenerationModelOutput(
|
return BlipForConditionalGenerationModelOutput(
|
||||||
@@ -1232,8 +1242,12 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
|
|||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
def get_input_embeddings(self) -> nn.Module:
|
def set_input_embeddings(self, value):
|
||||||
return self.vision_model.embeddings.patch_embedding
|
self.text_encoder.set_input_embeddings(value)
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
# This will return shared embeddings if they are shared else specific to encoder.
|
||||||
|
return self.text_encoder.get_input_embeddings()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
||||||
@@ -1474,8 +1488,11 @@ class BlipForImageTextRetrieval(BlipPreTrainedModel):
|
|||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
def get_input_embeddings(self) -> nn.Module:
|
def get_input_embeddings(self):
|
||||||
return self.vision_model.embeddings.patch_embedding
|
return self.text_encoder.get_input_embeddings()
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.text_encoder.set_input_embeddings(value)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
||||||
|
|||||||
@@ -817,6 +817,12 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin):
|
|||||||
self.cls = BlipTextOnlyMLMHead(config)
|
self.cls = BlipTextOnlyMLMHead(config)
|
||||||
self.label_smoothing = config.label_smoothing
|
self.label_smoothing = config.label_smoothing
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.bert.get_input_embeddings()
|
||||||
|
|
||||||
|
def set_input_embeddings(self, new_embeddings):
|
||||||
|
self.bert.set_input_embeddings(new_embeddings)
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.cls.predictions.decoder
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
|
|||||||
@@ -1768,11 +1768,12 @@ class Blip2Model(Blip2PreTrainedModel):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=True, # toggle for easier access to loss/logits below
|
||||||
labels=labels,
|
labels=labels,
|
||||||
)
|
)
|
||||||
loss = outputs.loss if return_dict else outputs[0]
|
loss = outputs.loss
|
||||||
logits = outputs.logits if return_dict else outputs[1]
|
logits = outputs.logits
|
||||||
|
outputs = outputs.to_tuple() if not return_dict else outputs
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits, vision_outputs, query_outputs, outputs)
|
output = (logits, vision_outputs, query_outputs, outputs)
|
||||||
@@ -1810,6 +1811,12 @@ class Blip2TextModelWithProjection(Blip2PreTrainedModel):
|
|||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings.word_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embeddings.word_embeddings = value
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BLIP_2_TEXT_WITH_PROJECTION_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BLIP_2_TEXT_WITH_PROJECTION_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=Blip2TextModelOutput, config_class=Blip2Config)
|
@replace_return_docstrings(output_type=Blip2TextModelOutput, config_class=Blip2Config)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -2233,11 +2240,12 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=True, # toggle for easier access to loss/logits below
|
||||||
labels=labels,
|
labels=labels,
|
||||||
)
|
)
|
||||||
loss = outputs.loss if return_dict else outputs[0]
|
loss = outputs.loss
|
||||||
logits = outputs.logits if return_dict else outputs[1]
|
logits = outputs.logits
|
||||||
|
outputs = outputs.to_tuple() if not return_dict else outputs
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits, vision_outputs, query_outputs, outputs)
|
output = (logits, vision_outputs, query_outputs, outputs)
|
||||||
@@ -2389,6 +2397,12 @@ class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
|
|||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embeddings.word_embeddings
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
self.embeddings.word_embeddings = value
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=Blip2ImageTextMatchingModelOutput, config_class=Blip2Config)
|
@replace_return_docstrings(output_type=Blip2ImageTextMatchingModelOutput, config_class=Blip2Config)
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -444,7 +444,7 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = True
|
||||||
test_attention_outputs = False
|
test_attention_outputs = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@@ -738,7 +738,6 @@ class BlipTextImageModelsModelTester:
|
|||||||
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"labels": input_ids,
|
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
}
|
}
|
||||||
@@ -787,10 +786,10 @@ class BlipVQAModelTester:
|
|||||||
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"labels": input_ids,
|
|
||||||
"decoder_input_ids": input_ids,
|
"decoder_input_ids": input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
|
"labels": input_ids,
|
||||||
}
|
}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
@@ -802,7 +801,7 @@ class BlipVQAModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = True
|
||||||
test_attention_outputs = False
|
test_attention_outputs = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
|
||||||
@@ -811,7 +810,6 @@ class BlipVQAModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def _prepare_inputs_for_vqa(self):
|
def _prepare_inputs_for_vqa(self):
|
||||||
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
|
||||||
inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
|
inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
|
||||||
inputs_dict.pop("return_loss")
|
inputs_dict.pop("return_loss")
|
||||||
return inputs_dict
|
return inputs_dict
|
||||||
@@ -882,7 +880,7 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = True
|
||||||
test_attention_outputs = False
|
test_attention_outputs = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
|
||||||
@@ -1110,7 +1108,7 @@ class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = True
|
||||||
test_attention_outputs = False
|
test_attention_outputs = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
"""Testing suite for the PyTorch BLIP-2 model."""
|
"""Testing suite for the PyTorch BLIP-2 model."""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@@ -32,7 +33,7 @@ from transformers.testing_utils import (
|
|||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
from transformers.utils import is_torch_available, is_vision_available
|
from transformers.utils import is_torch_available, is_torch_sdpa_available, is_vision_available
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -443,7 +444,6 @@ class Blip2ForConditionalGenerationDecoderOnlyModelTester:
|
|||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"labels": input_ids,
|
|
||||||
}
|
}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
@@ -456,7 +456,7 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
|
|||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = False
|
||||||
test_attention_outputs = False
|
test_attention_outputs = False
|
||||||
test_torchscript = False
|
test_torchscript = True
|
||||||
_is_composite = True
|
_is_composite = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@@ -466,6 +466,116 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs)
|
self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs)
|
||||||
|
|
||||||
|
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||||
|
# overwrite because BLIP requires ipnut ids and pixel values as input
|
||||||
|
if not self.test_torchscript:
|
||||||
|
self.skipTest(reason="test_torchscript is set to `False`")
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||||
|
configs_no_init.torchscript = True
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
for attn_implementation in ["eager", "sdpa"]:
|
||||||
|
if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
|
||||||
|
continue
|
||||||
|
|
||||||
|
configs_no_init._attn_implementation = attn_implementation
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
|
||||||
|
main_input_name = model_class.main_input_name
|
||||||
|
|
||||||
|
try:
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||||
|
main_input = inputs[main_input_name]
|
||||||
|
input_ids = inputs["input_ids"]
|
||||||
|
attention_mask = inputs["attention_mask"]
|
||||||
|
decoder_input_ids = inputs["decoder_input_ids"]
|
||||||
|
decoder_attention_mask = inputs["decoder_attention_mask"]
|
||||||
|
model(main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||||
|
traced_model = torch.jit.trace(
|
||||||
|
model, (main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
main_input = inputs[main_input_name]
|
||||||
|
input_ids = inputs["input_ids"]
|
||||||
|
|
||||||
|
if model.config._attn_implementation == "sdpa":
|
||||||
|
trace_input = {main_input_name: main_input, "input_ids": input_ids}
|
||||||
|
|
||||||
|
if "attention_mask" in inputs:
|
||||||
|
trace_input["attention_mask"] = inputs["attention_mask"]
|
||||||
|
else:
|
||||||
|
self.skipTest(reason="testing SDPA without attention_mask is not supported")
|
||||||
|
|
||||||
|
model(main_input, attention_mask=inputs["attention_mask"])
|
||||||
|
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
|
||||||
|
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
|
||||||
|
else:
|
||||||
|
model(main_input, input_ids)
|
||||||
|
traced_model = torch.jit.trace(model, (main_input, input_ids))
|
||||||
|
except RuntimeError:
|
||||||
|
self.fail("Couldn't trace module.")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
|
||||||
|
|
||||||
|
try:
|
||||||
|
torch.jit.save(traced_model, pt_file_name)
|
||||||
|
except Exception:
|
||||||
|
self.fail("Couldn't save module.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
loaded_model = torch.jit.load(pt_file_name)
|
||||||
|
except Exception:
|
||||||
|
self.fail("Couldn't load module.")
|
||||||
|
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
loaded_model.to(torch_device)
|
||||||
|
loaded_model.eval()
|
||||||
|
|
||||||
|
model_state_dict = model.state_dict()
|
||||||
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
|
models_equal = True
|
||||||
|
for layer_name, p1 in model_state_dict.items():
|
||||||
|
if layer_name in loaded_model_state_dict:
|
||||||
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
if p1.data.ne(p2.data).sum() > 0:
|
||||||
|
models_equal = False
|
||||||
|
|
||||||
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
|
||||||
|
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||||
|
self.clear_torch_jit_class_registry()
|
||||||
|
|
||||||
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||||
def test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
pass
|
pass
|
||||||
@@ -754,7 +864,6 @@ class Blip2ModelTester:
|
|||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
"labels": labels,
|
|
||||||
}
|
}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
@@ -775,9 +884,9 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
|
|||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = True
|
||||||
test_attention_outputs = False
|
test_attention_outputs = False
|
||||||
test_torchscript = False
|
test_torchscript = True
|
||||||
_is_composite = True
|
_is_composite = True
|
||||||
|
|
||||||
# TODO: Fix the failed tests
|
# TODO: Fix the failed tests
|
||||||
@@ -804,6 +913,116 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs)
|
self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs)
|
||||||
|
|
||||||
|
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||||
|
# overwrite because BLIP requires ipnut ids and pixel values as input
|
||||||
|
if not self.test_torchscript:
|
||||||
|
self.skipTest(reason="test_torchscript is set to `False`")
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||||
|
configs_no_init.torchscript = True
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
for attn_implementation in ["eager", "sdpa"]:
|
||||||
|
if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
|
||||||
|
continue
|
||||||
|
|
||||||
|
configs_no_init._attn_implementation = attn_implementation
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
|
||||||
|
main_input_name = model_class.main_input_name
|
||||||
|
|
||||||
|
try:
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
|
||||||
|
main_input = inputs[main_input_name]
|
||||||
|
input_ids = inputs["input_ids"]
|
||||||
|
attention_mask = inputs["attention_mask"]
|
||||||
|
decoder_input_ids = inputs["decoder_input_ids"]
|
||||||
|
decoder_attention_mask = inputs["decoder_attention_mask"]
|
||||||
|
model(main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||||
|
traced_model = torch.jit.trace(
|
||||||
|
model, (main_input, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
main_input = inputs[main_input_name]
|
||||||
|
input_ids = inputs["input_ids"]
|
||||||
|
|
||||||
|
if model.config._attn_implementation == "sdpa":
|
||||||
|
trace_input = {main_input_name: main_input, "input_ids": input_ids}
|
||||||
|
|
||||||
|
if "attention_mask" in inputs:
|
||||||
|
trace_input["attention_mask"] = inputs["attention_mask"]
|
||||||
|
else:
|
||||||
|
self.skipTest(reason="testing SDPA without attention_mask is not supported")
|
||||||
|
|
||||||
|
model(main_input, attention_mask=inputs["attention_mask"])
|
||||||
|
# example_kwarg_inputs was introduced in torch==2.0, but it is fine here since SDPA has a requirement on torch>=2.1.
|
||||||
|
traced_model = torch.jit.trace(model, example_kwarg_inputs=trace_input)
|
||||||
|
else:
|
||||||
|
model(main_input, input_ids)
|
||||||
|
traced_model = torch.jit.trace(model, (main_input, input_ids))
|
||||||
|
except RuntimeError:
|
||||||
|
self.fail("Couldn't trace module.")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
|
||||||
|
|
||||||
|
try:
|
||||||
|
torch.jit.save(traced_model, pt_file_name)
|
||||||
|
except Exception:
|
||||||
|
self.fail("Couldn't save module.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
loaded_model = torch.jit.load(pt_file_name)
|
||||||
|
except Exception:
|
||||||
|
self.fail("Couldn't load module.")
|
||||||
|
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
loaded_model.to(torch_device)
|
||||||
|
loaded_model.eval()
|
||||||
|
|
||||||
|
model_state_dict = model.state_dict()
|
||||||
|
loaded_model_state_dict = loaded_model.state_dict()
|
||||||
|
|
||||||
|
non_persistent_buffers = {}
|
||||||
|
for key in loaded_model_state_dict.keys():
|
||||||
|
if key not in model_state_dict.keys():
|
||||||
|
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||||
|
|
||||||
|
loaded_model_state_dict = {
|
||||||
|
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||||
|
}
|
||||||
|
|
||||||
|
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||||
|
|
||||||
|
model_buffers = list(model.buffers())
|
||||||
|
for non_persistent_buffer in non_persistent_buffers.values():
|
||||||
|
found_buffer = False
|
||||||
|
for i, model_buffer in enumerate(model_buffers):
|
||||||
|
if torch.equal(non_persistent_buffer, model_buffer):
|
||||||
|
found_buffer = True
|
||||||
|
break
|
||||||
|
|
||||||
|
self.assertTrue(found_buffer)
|
||||||
|
model_buffers.pop(i)
|
||||||
|
|
||||||
|
models_equal = True
|
||||||
|
for layer_name, p1 in model_state_dict.items():
|
||||||
|
if layer_name in loaded_model_state_dict:
|
||||||
|
p2 = loaded_model_state_dict[layer_name]
|
||||||
|
if p1.data.ne(p2.data).sum() > 0:
|
||||||
|
models_equal = False
|
||||||
|
|
||||||
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
|
||||||
|
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||||
|
self.clear_torch_jit_class_registry()
|
||||||
|
|
||||||
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||||
def test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
pass
|
pass
|
||||||
@@ -942,7 +1161,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
|
|||||||
def test_get_image_features(self):
|
def test_get_image_features(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"]
|
keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
||||||
|
|
||||||
for key in keys_to_pop:
|
for key in keys_to_pop:
|
||||||
inputs_dict.pop(key)
|
inputs_dict.pop(key)
|
||||||
@@ -962,7 +1181,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi
|
|||||||
def test_get_qformer_features(self):
|
def test_get_qformer_features(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"]
|
keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
|
||||||
|
|
||||||
for key in keys_to_pop:
|
for key in keys_to_pop:
|
||||||
inputs_dict.pop(key)
|
inputs_dict.pop(key)
|
||||||
@@ -1072,7 +1291,7 @@ class Blip2TextModelWithProjectionTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
|
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = True
|
||||||
test_attention_outputs = False
|
test_attention_outputs = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
|
||||||
@@ -1396,7 +1615,7 @@ class Blip2TextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = True
|
||||||
test_attention_outputs = False
|
test_attention_outputs = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
|
||||||
|
|||||||
@@ -459,7 +459,7 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
|
|||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = True
|
||||||
test_attention_outputs = False
|
test_attention_outputs = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
_is_composite = True
|
_is_composite = True
|
||||||
|
|||||||
@@ -479,7 +479,7 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
|
|||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = True
|
||||||
test_attention_outputs = False
|
test_attention_outputs = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
_is_composite = True
|
_is_composite = True
|
||||||
|
|||||||
@@ -1811,6 +1811,7 @@ class ModelTesterMixin:
|
|||||||
original_config,
|
original_config,
|
||||||
inputs_dict,
|
inputs_dict,
|
||||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
inputs_dict.pop("labels", None)
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
config = copy.deepcopy(original_config)
|
config = copy.deepcopy(original_config)
|
||||||
@@ -1988,6 +1989,7 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
original_config.tie_word_embeddings = False
|
original_config.tie_word_embeddings = False
|
||||||
|
inputs_dict.pop("labels", None)
|
||||||
|
|
||||||
# if model cannot untied embeddings -> leave test
|
# if model cannot untied embeddings -> leave test
|
||||||
if original_config.tie_word_embeddings:
|
if original_config.tie_word_embeddings:
|
||||||
|
|||||||
Reference in New Issue
Block a user