Blip: get/set input embeddings correctly (#34152)

* set-get embeds

* add tests

* fix tests

* remove

* return dict True

* fix tests

* why did i remove this

* enabel torchscript tests
This commit is contained in:
Raushan Turganbay
2024-11-01 08:39:39 +01:00
committed by GitHub
parent b53e44e847
commit 6beb3f1691
8 changed files with 288 additions and 32 deletions

View File

@@ -444,7 +444,7 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_resize_embeddings = True
test_attention_outputs = False
def setUp(self):
@@ -738,7 +738,6 @@ class BlipTextImageModelsModelTester:
config, input_ids, attention_mask, pixel_values = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"labels": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
}
@@ -787,10 +786,10 @@ class BlipVQAModelTester:
config, input_ids, attention_mask, pixel_values = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"labels": input_ids,
"decoder_input_ids": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"labels": input_ids,
}
return config, inputs_dict
@@ -802,7 +801,7 @@ class BlipVQAModelTest(ModelTesterMixin, unittest.TestCase):
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_resize_embeddings = True
test_attention_outputs = False
test_torchscript = False
@@ -811,7 +810,6 @@ class BlipVQAModelTest(ModelTesterMixin, unittest.TestCase):
def _prepare_inputs_for_vqa(self):
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict["labels"] = inputs_dict["input_ids"]
inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
inputs_dict.pop("return_loss")
return inputs_dict
@@ -882,7 +880,7 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_resize_embeddings = True
test_attention_outputs = False
test_torchscript = False
@@ -1110,7 +1108,7 @@ class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_resize_embeddings = True
test_attention_outputs = False
test_torchscript = False