Blip: get/set input embeddings correctly (#34152)
* set-get embeds * add tests * fix tests * remove * return dict True * fix tests * why did i remove this * enabel torchscript tests
This commit is contained in:
committed by
GitHub
parent
b53e44e847
commit
6beb3f1691
@@ -444,7 +444,7 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_resize_embeddings = True
|
||||
test_attention_outputs = False
|
||||
|
||||
def setUp(self):
|
||||
@@ -738,7 +738,6 @@ class BlipTextImageModelsModelTester:
|
||||
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"labels": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
@@ -787,10 +786,10 @@ class BlipVQAModelTester:
|
||||
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"labels": input_ids,
|
||||
"decoder_input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
"labels": input_ids,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
@@ -802,7 +801,7 @@ class BlipVQAModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_resize_embeddings = True
|
||||
test_attention_outputs = False
|
||||
test_torchscript = False
|
||||
|
||||
@@ -811,7 +810,6 @@ class BlipVQAModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
def _prepare_inputs_for_vqa(self):
|
||||
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||
inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
|
||||
inputs_dict.pop("return_loss")
|
||||
return inputs_dict
|
||||
@@ -882,7 +880,7 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_resize_embeddings = True
|
||||
test_attention_outputs = False
|
||||
test_torchscript = False
|
||||
|
||||
@@ -1110,7 +1108,7 @@ class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_resize_embeddings = True
|
||||
test_attention_outputs = False
|
||||
test_torchscript = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user