Add visual prompt to processor of CLIPSeg model (#20816)

Adds a `visual_prompt` argument to `CLIPSegProcessor` to enable image-guided segmentation.
This commit is contained in:
İdil Sülo
2022-12-21 13:23:45 +01:00
committed by GitHub
parent 2da82bb4a7
commit 0ae58204c6
2 changed files with 42 additions and 5 deletions

View File

@@ -56,7 +56,7 @@ class CLIPSegProcessor(ProcessorMixin):
super().__init__(image_processor, tokenizer) super().__init__(image_processor, tokenizer)
def __call__(self, text=None, images=None, return_tensors=None, **kwargs): def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=None, **kwargs):
""" """
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
@@ -73,6 +73,10 @@ class CLIPSegProcessor(ProcessorMixin):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width. number of channels, H and W are image height and width.
visual_prompt (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The visual prompt image or batch of images to be prepared. Each visual prompt image can be a PIL image,
NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape
(C, H, W), where C is a number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*): return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are: If set, will return tensors of a particular framework. Acceptable values are:
@@ -91,21 +95,37 @@ class CLIPSegProcessor(ProcessorMixin):
`None`). `None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
""" """
if text is None and visual_prompt is None and images is None:
raise ValueError("You have to specify either text, visual prompt or images.")
if text is None and images is None: if text is not None and visual_prompt is not None:
raise ValueError("You have to specify either text or images. Both cannot be none.") raise ValueError("You have to specify exactly one type of prompt. Either text or visual prompt.")
if text is not None: if text is not None:
encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs) encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
if visual_prompt is not None:
prompt_features = self.image_processor(visual_prompt, return_tensors=return_tensors, **kwargs)
if images is not None: if images is not None:
image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs) image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
if text is not None and images is not None: if visual_prompt is not None and images is not None:
encoding = {
"pixel_values": image_features.pixel_values,
"conditional_pixel_values": prompt_features.pixel_values,
}
return encoding
elif text is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values encoding["pixel_values"] = image_features.pixel_values
return encoding return encoding
elif text is not None: elif text is not None:
return encoding return encoding
elif visual_prompt is not None:
encoding = {
"conditional_pixel_values": prompt_features.pixel_values,
}
return encoding
else: else:
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)

View File

@@ -157,7 +157,7 @@ class CLIPSegProcessorTest(unittest.TestCase):
for key in encoded_tok.keys(): for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key]) self.assertListEqual(encoded_tok[key], encoded_processor[key])
def test_processor(self): def test_processor_text(self):
image_processor = self.get_image_processor() image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
@@ -174,6 +174,23 @@ class CLIPSegProcessorTest(unittest.TestCase):
with pytest.raises(ValueError): with pytest.raises(ValueError):
processor() processor()
def test_processor_visual_prompt(self):
    """Check the processor output when images are paired with a visual prompt.

    Builds a `CLIPSegProcessor` from the fixture tokenizer and image processor, passes it
    one set of image inputs as `images` and another as `visual_prompt`, and verifies that
    the result exposes exactly the keys `pixel_values` and `conditional_pixel_values`, in
    that order. Also verifies that calling the processor with no arguments raises
    `ValueError`.
    """
    img_proc = self.get_image_processor()
    tok = self.get_tokenizer()
    clipseg_processor = CLIPSegProcessor(tokenizer=tok, image_processor=img_proc)

    target_images = self.prepare_image_inputs()
    prompt_images = self.prepare_image_inputs()
    outputs = clipseg_processor(images=target_images, visual_prompt=prompt_images)

    expected_keys = ["pixel_values", "conditional_pixel_values"]
    self.assertListEqual(list(outputs.keys()), expected_keys)

    # A call that supplies neither a prompt nor images must be rejected.
    with pytest.raises(ValueError):
        clipseg_processor()
def test_tokenizer_decode(self): def test_tokenizer_decode(self):
image_processor = self.get_image_processor() image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()