From 0ae58204c68f55842a8f70d835e3ed9da12acf23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=B0dil=20S=C3=BClo?= Date: Wed, 21 Dec 2022 13:23:45 +0100 Subject: [PATCH] Add visual prompt to processor of CLIPSeg model (#20816) Adds visual_prompt argument to CLIPSegProcessor to enable image-guided segmentation --- .../models/clipseg/processing_clipseg.py | 28 ++++++++++++++++--- .../models/clipseg/test_processor_clipseg.py | 19 ++++++++++++- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/clipseg/processing_clipseg.py b/src/transformers/models/clipseg/processing_clipseg.py index cb12269295..df3705e99e 100644 --- a/src/transformers/models/clipseg/processing_clipseg.py +++ b/src/transformers/models/clipseg/processing_clipseg.py @@ -56,7 +56,7 @@ class CLIPSegProcessor(ProcessorMixin): super().__init__(image_processor, tokenizer) - def __call__(self, text=None, images=None, return_tensors=None, **kwargs): + def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=None, **kwargs): """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode @@ -73,6 +73,10 @@ class CLIPSegProcessor(ProcessorMixin): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a number of channels, H and W are image height and width. + visual_prompt (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The visual prompt image or batch of images to be prepared. Each visual prompt image can be a PIL image, + NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape + (C, H, W), where C is a number of channels, H and W are image height and width. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: @@ -91,21 +95,37 @@ class CLIPSegProcessor(ProcessorMixin): `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ + if text is None and visual_prompt is None and images is None: + raise ValueError("You have to specify either text, visual prompt or images.") - if text is None and images is None: - raise ValueError("You have to specify either text or images. Both cannot be none.") + if text is not None and visual_prompt is not None: + raise ValueError("You have to specify exactly one type of prompt. Either text or visual prompt.") if text is not None: encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs) + if visual_prompt is not None: + prompt_features = self.image_processor(visual_prompt, return_tensors=return_tensors, **kwargs) + if images is not None: image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs) - if text is not None and images is not None: + if visual_prompt is not None and images is not None: + encoding = { + "pixel_values": image_features.pixel_values, + "conditional_pixel_values": prompt_features.pixel_values, + } + return encoding + elif text is not None and images is not None: encoding["pixel_values"] = image_features.pixel_values return encoding elif text is not None: return encoding + elif visual_prompt is not None: + encoding = { + "conditional_pixel_values": prompt_features.pixel_values, + } + return encoding else: return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) diff --git a/tests/models/clipseg/test_processor_clipseg.py b/tests/models/clipseg/test_processor_clipseg.py index 03ea5dd962..2bc82dd022 100644 --- a/tests/models/clipseg/test_processor_clipseg.py +++ b/tests/models/clipseg/test_processor_clipseg.py @@ -157,7 +157,7 @@ class CLIPSegProcessorTest(unittest.TestCase): for key in encoded_tok.keys(): self.assertListEqual(encoded_tok[key], encoded_processor[key]) - def test_processor(self): + def test_processor_text(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer() @@ -174,6 +174,23 @@ class CLIPSegProcessorTest(unittest.TestCase): with pytest.raises(ValueError): processor() + def test_processor_visual_prompt(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = CLIPSegProcessor(tokenizer=tokenizer, image_processor=image_processor) + + image_input = self.prepare_image_inputs() + visual_prompt_input = self.prepare_image_inputs() + + inputs = processor(images=image_input, visual_prompt=visual_prompt_input) + + self.assertListEqual(list(inputs.keys()), ["pixel_values", "conditional_pixel_values"]) + + # test if it raises when no input is passed + with pytest.raises(ValueError): + processor() + def test_tokenizer_decode(self): image_processor = self.get_image_processor() tokenizer = self.get_tokenizer()