uniformize kwargs for SAM (#34578)

* Make kwargs uniform for SAM

* Remove unused attribute

* Make point_pad_value part of image_kwargs

* Update annotations

* Code review - use existing methods

* Use ProcessorTesterMixin

* Do not add ProcessorTesterMixin everywhere
This commit is contained in:
Tibor Reiss
2024-12-23 13:54:57 +01:00
committed by GitHub
parent 2bb60982ac
commit e10be82b71
2 changed files with 81 additions and 29 deletions

View File

@@ -17,13 +17,14 @@ Processor class for SAM.
""" """
from copy import deepcopy from copy import deepcopy
from typing import Optional, Union from typing import List, Optional, Union
import numpy as np import numpy as np
from ...processing_utils import ProcessorMixin from ...image_utils import ImageInput, VideoInput
from ...tokenization_utils_base import BatchEncoding from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
from ...utils import TensorType, is_tf_available, is_torch_available from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
from ...utils import is_tf_available, is_torch_available
if is_torch_available(): if is_torch_available():
@@ -33,6 +34,23 @@ if is_tf_available():
import tensorflow as tf import tensorflow as tf
class SamImagesKwargs(ImagesKwargs):
segmentation_maps: Optional[ImageInput]
input_points: Optional[List[List[float]]]
input_labels: Optional[List[List[int]]]
input_boxes: Optional[List[List[List[float]]]]
point_pad_value: Optional[int]
class SamProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: SamImagesKwargs
_defaults = {
"images_kwargs": {
"point_pad_value": -10,
}
}
class SamProcessor(ProcessorMixin): class SamProcessor(ProcessorMixin):
r""" r"""
Constructs a SAM processor which wraps a SAM image processor and a 2D points & Bounding boxes processor into a Constructs a SAM processor which wraps a SAM image processor and a 2D points & Bounding boxes processor into a
@@ -48,32 +66,50 @@ class SamProcessor(ProcessorMixin):
attributes = ["image_processor"] attributes = ["image_processor"]
image_processor_class = "SamImageProcessor" image_processor_class = "SamImageProcessor"
# For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
optional_call_args = [
"segmentation_maps",
"input_points",
"input_labels",
"input_boxes",
]
def __init__(self, image_processor): def __init__(self, image_processor):
super().__init__(image_processor) super().__init__(image_processor)
self.current_processor = self.image_processor
self.point_pad_value = -10
self.target_size = self.image_processor.size["longest_edge"] self.target_size = self.image_processor.size["longest_edge"]
def __call__( def __call__(
self, self,
images=None, images: Optional[ImageInput] = None,
segmentation_maps=None, # The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes`
input_points=None, # arguments that may be passed as a positional argument.
input_labels=None, # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
input_boxes=None, # or this conversation for more context:
return_tensors: Optional[Union[str, TensorType]] = None, # https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
# This behavior is only needed for backward compatibility and will be removed in future versions.
*args, # to be deprecated
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
audio: Optional[AudioInput] = None,
video: Optional[VideoInput] = None,
**kwargs, **kwargs,
) -> BatchEncoding: ) -> BatchEncoding:
""" """
This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D
points and bounding boxes for the model if they are provided. points and bounding boxes for the model if they are provided.
""" """
output_kwargs = self._merge_kwargs(
SamProcessorKwargs,
tokenizer_init_kwargs={},
**kwargs,
**self.prepare_and_validate_optional_call_args(*args),
)
input_points = output_kwargs["images_kwargs"].pop("input_points", None)
input_labels = output_kwargs["images_kwargs"].pop("input_labels", None)
input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None)
encoding_image_processor = self.image_processor( encoding_image_processor = self.image_processor(
images, images,
segmentation_maps=segmentation_maps, **output_kwargs["images_kwargs"],
return_tensors=return_tensors,
**kwargs,
) )
# pop arguments that are not used in the forward but used nevertheless # pop arguments that are not used in the forward but used nevertheless
@@ -94,7 +130,8 @@ class SamProcessor(ProcessorMixin):
input_points=input_points, input_points=input_points,
input_labels=input_labels, input_labels=input_labels,
input_boxes=input_boxes, input_boxes=input_boxes,
return_tensors=return_tensors, return_tensors=output_kwargs["common_kwargs"].get("return_tensors"),
point_pad_value=output_kwargs["images_kwargs"].get("point_pad_value"),
) )
return encoding_image_processor return encoding_image_processor
@@ -107,6 +144,7 @@ class SamProcessor(ProcessorMixin):
input_labels=None, input_labels=None,
input_boxes=None, input_boxes=None,
return_tensors="pt", return_tensors="pt",
point_pad_value=-10,
): ):
if input_points is not None: if input_points is not None:
if len(original_sizes) != len(input_points): if len(original_sizes) != len(input_points):
@@ -121,7 +159,9 @@ class SamProcessor(ProcessorMixin):
# check that all arrays have the same shape # check that all arrays have the same shape
if not all(point.shape == input_points[0].shape for point in input_points): if not all(point.shape == input_points[0].shape for point in input_points):
if input_labels is not None: if input_labels is not None:
input_points, input_labels = self._pad_points_and_labels(input_points, input_labels) input_points, input_labels = self._pad_points_and_labels(
input_points, input_labels, point_pad_value
)
input_points = np.array(input_points) input_points = np.array(input_points)
@@ -174,7 +214,7 @@ class SamProcessor(ProcessorMixin):
return encoding_image_processor return encoding_image_processor
def _pad_points_and_labels(self, input_points, input_labels): def _pad_points_and_labels(self, input_points, input_labels, point_pad_value):
r""" r"""
The method pads the 2D points and labels to the maximum number of points in the batch. The method pads the 2D points and labels to the maximum number of points in the batch.
""" """
@@ -183,9 +223,9 @@ class SamProcessor(ProcessorMixin):
for i, point in enumerate(input_points): for i, point in enumerate(input_points):
if point.shape[0] != expected_nb_points: if point.shape[0] != expected_nb_points:
point = np.concatenate( point = np.concatenate(
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0 [point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
) )
input_labels[i] = np.append(input_labels[i], [self.point_pad_value]) input_labels[i] = np.append(input_labels[i], [point_pad_value])
processed_input_points.append(point) processed_input_points.append(point)
input_points = processed_input_points input_points = processed_input_points
return input_points, input_labels return input_points, input_labels

View File

@@ -26,7 +26,7 @@ from transformers.testing_utils import (
) )
from transformers.utils import is_tf_available, is_torch_available, is_vision_available from transformers.utils import is_tf_available, is_torch_available, is_vision_available
from ...test_processing_common import prepare_image_inputs from ...test_processing_common import ProcessorTesterMixin, prepare_image_inputs
if is_vision_available(): if is_vision_available():
@@ -43,7 +43,9 @@ if is_tf_available():
@require_vision @require_vision
@require_torchvision @require_torchvision
class SamProcessorTest(unittest.TestCase): class SamProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = SamProcessor
def setUp(self): def setUp(self):
self.tmpdirname = tempfile.mkdtemp() self.tmpdirname = tempfile.mkdtemp()
image_processor = SamImageProcessor() image_processor = SamImageProcessor()
@@ -56,11 +58,6 @@ class SamProcessorTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
shutil.rmtree(self.tmpdirname) shutil.rmtree(self.tmpdirname)
# Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. only contains an image processor
def prepare_image_inputs(self):
"""This function prepares a list of PIL images."""
return prepare_image_inputs()
def prepare_mask_inputs(self): def prepare_mask_inputs(self):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True. or a list of PyTorch tensors if one specifies torchify=True.
@@ -69,6 +66,21 @@ class SamProcessorTest(unittest.TestCase):
mask_inputs = [Image.fromarray(x) for x in mask_inputs] mask_inputs = [Image.fromarray(x) for x in mask_inputs]
return mask_inputs return mask_inputs
def test_chat_template_save_loading(self):
self.skipTest("SamProcessor does not have a tokenizer")
def test_image_processor_defaults_preserved_by_image_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")
def test_kwargs_overrides_default_image_processor_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")
def test_kwargs_overrides_default_tokenizer_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")
def test_tokenizer_defaults_preserved_by_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")
def test_save_load_pretrained_additional_features(self): def test_save_load_pretrained_additional_features(self):
processor = SamProcessor(image_processor=self.get_image_processor()) processor = SamProcessor(image_processor=self.get_image_processor())
processor.save_pretrained(self.tmpdirname) processor.save_pretrained(self.tmpdirname)
@@ -165,7 +177,7 @@ class TFSamProcessorTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
shutil.rmtree(self.tmpdirname) shutil.rmtree(self.tmpdirname)
# Processor tester class can't use ProcessorTesterMixin as processor is atypical e.g. only contains an image processor and it assumes torch # This is to avoid repeating the skipping of the common tests
def prepare_image_inputs(self): def prepare_image_inputs(self):
"""This function prepares a list of PIL images.""" """This function prepares a list of PIL images."""
return prepare_image_inputs() return prepare_image_inputs()
@@ -248,7 +260,7 @@ class SamProcessorEquivalenceTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
shutil.rmtree(self.tmpdirname) shutil.rmtree(self.tmpdirname)
# Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. only contains an image processor # This is to avoid repeating the skipping of the common tests
def prepare_image_inputs(self): def prepare_image_inputs(self):
"""This function prepares a list of PIL images.""" """This function prepares a list of PIL images."""
return prepare_image_inputs() return prepare_image_inputs()