🚨 Add Blip2ForImageTextRetrieval (#29261)
* add Blip2ForImageTextRetrieval
* use one line and remove unnecessary space in tests

  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* use value from the config, rather than hardcoded
* change order of params in Blip2QFormerModel.forward
* update docstring
* fix style
* update test_inference_opt
* move embeddings out of Blip2QFormerModel
* remove from_vision_qformer_configs
* remove autocast float16 in Blip2QFormerModel
* rename fields into vision_projection, text_projection, use_image_text_matching_head
* use CLIPOutput for Blip2ImageTextMatchingModelOutput
* remove past_key_values_length from Blip2TextEmbeddings
* fix small typo in the CLIPOutput docstring
* add Blip2ForImageTextRetrieval to Zero Shot Image Classification mapping
* update docstring and add require_torch_fp16
* rollback test_inference_opt
* use use_image_text_matching_head=True in convert
* skip test_model_get_set_embeddings
* fix create_rename_keys error on new itm fields
* revert to do scale after dot product between "query" and "key"
* fix ValueError on convert script for blip2-opt-2.7b
* update org of paths to Salesforce
* add is_pipeline_test_to_skip for VisualQuestionAnsweringPipelineTests
* [run_slow] blip_2
* removed Blip2ForImageTextRetrieval from IGNORE_NON_AUTO_CONFIGURED
* fix docstring of Blip2ImageTextMatchingModelOutput
* [run_slow] blip_2
* fix multi-gpu tests
* [run_slow] blip_2
* [run_slow] blip_2

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@@ -279,3 +279,46 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase):
                 ]
             ]
             * 5,
         )
+
+    @slow
+    @require_torch
+    def test_blip2_model_pt(self):
+        image_classifier = pipeline(
+            task="zero-shot-image-classification",
+            model="Salesforce/blip2-itm-vit-g",
+        )
+        # This is an image of 2 cats with remotes and no planes
+        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+        output = image_classifier(
+            image,
+            candidate_labels=["2 cats", "a plane", "a remote"],
+            tokenizer_kwargs={"return_token_type_ids": False},
+        )
+
+        self.assertEqual(
+            nested_simplify(output),
+            [
+                {"score": 0.369, "label": "2 cats"},
+                {"score": 0.333, "label": "a remote"},
+                {"score": 0.297, "label": "a plane"},
+            ],
+        )
+
+        output = image_classifier(
+            [image] * 5,
+            candidate_labels=["2 cats", "a plane", "a remote"],
+            batch_size=2,
+            tokenizer_kwargs={"return_token_type_ids": False},
+        )
+
+        self.assertEqual(
+            nested_simplify(output),
+            [
+                [
+                    {"score": 0.369, "label": "2 cats"},
+                    {"score": 0.333, "label": "a remote"},
+                    {"score": 0.297, "label": "a plane"},
+                ]
+            ]
+            * 5,
+        )
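
Note on the expected values: the three scores in the assertions sum to roughly 1 and are listed highest first, which is consistent with a softmax over the per-label logits followed by a descending sort. Below is a minimal pure-Python sketch of that post-processing step, assuming this is how the scores are produced. The logits and the `rank_labels` helper are hypothetical and only illustrate the output shape; they are not taken from the model or from the pipeline source.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def rank_labels(logits, candidate_labels):
    """Hypothetical helper: turn per-label logits into sorted score dicts."""
    probs = softmax(logits)
    # Highest score first, mirroring the ordering asserted in the test.
    ranked = sorted(zip(probs, candidate_labels), key=lambda pair: -pair[0])
    return [{"score": round(score, 3), "label": label} for score, label in ranked]


# Made-up logits for illustration only (not real model outputs).
example = rank_labels([0.2, -0.02, 0.1], ["2 cats", "a plane", "a remote"])
print(example)
```

Because the scores are normalized across the candidate labels, changing the label list changes every score, so the asserted values hold only for this exact set of three labels.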