Adds image-guided object detection support to OWL-ViT (#20136)

Adds an image-guided object detection method to the OwlViTForObjectDetection class, as described in the original paper. One-shot (image-guided) object detection lets users supply a query image and search for similar objects in the input image.

Co-Authored-By: Dhruv Karan k4r4n.dhruv@gmail.com
This commit is contained in:
Alara Dirik
2022-11-16 09:07:46 +03:00
committed by GitHub
parent 0d0d77693f
commit a00b7e85ea
7 changed files with 582 additions and 138 deletions

View File

@@ -227,6 +227,23 @@ class OwlViTProcessorTest(unittest.TestCase):
self.assertListEqual(list(input_ids[0]), predicted_ids[0])
self.assertListEqual(list(input_ids[1]), predicted_ids[1])
def test_processor_case2(self):
    """Processor called with images and query images returns both pixel-value keys.

    Also verifies that calling the processor without any input raises
    ``ValueError``.
    """
    extractor = self.get_feature_extractor()
    tok = self.get_tokenizer()
    owlvit_processor = OwlViTProcessor(tokenizer=tok, feature_extractor=extractor)

    # Two independent batches: one searched in, one used as the query.
    target_images = self.prepare_image_inputs()
    query_images = self.prepare_image_inputs()
    encoded = owlvit_processor(images=target_images, query_images=query_images)

    # Key order matters here: the comparison is list-based, not set-based.
    expected_keys = ["query_pixel_values", "pixel_values"]
    self.assertListEqual(list(encoded.keys()), expected_keys)

    # test if it raises when no input is passed
    with pytest.raises(ValueError):
        owlvit_processor()
def test_tokenizer_decode(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
@@ -239,16 +256,3 @@ class OwlViTProcessorTest(unittest.TestCase):
decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
    """Keys produced for text + image input match ``processor.model_input_names``."""
    extractor = self.get_feature_extractor()
    tok = self.get_tokenizer()
    owlvit_processor = OwlViTProcessor(tokenizer=tok, feature_extractor=extractor)

    prompt = "lower newer"
    images = self.prepare_image_inputs()
    encoded = owlvit_processor(text=prompt, images=images)

    # List comparison: both the set of keys and their order must agree.
    produced_keys = list(encoded.keys())
    self.assertListEqual(produced_keys, owlvit_processor.model_input_names)