Add automatic-mask-generation pipeline for Segment Anything Model (SAM) (#22840)
* cleanup * updates * more refactoring * make style * update inits * support other inputs in base * update based on review Co-authored-by: Nicolas Patry <patry.nicolas@gmail.com> * Update tests/pipelines/test_pipelines_automatic_mask_generation.py Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * update * fixup * TODO x and y to refactor, _h _w refactored here * update docstring * more nits * style on these * more doc fix * rename variables * update * updates * style * update * fix `_mask_to_rle_pytorch` * styling * fix ask to rle, wrong outputs * add device arg * update * more updates, fix tets * udpate * update docstrings * styling * fixup * add notebook on the docs * update orginal sizes * fix docstring * updat condition on point_per-batch * updates tests * fix CI test * extend is required, append does not work! * fixup * fix CI tests * whit pixels left * address doc comments * fix doc * slow pipeline tests * update auto init * add revision * make fixup * update p!ipoeline tag when calling tests * alphabeitcal order in inits * fix copies * last style nits * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * reformat docstring * more reformat * address most of the comments * Update src/transformers/pipelines/mask_generation.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * final refactor * Update src/transformers/models/sam/image_processing_sam.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fixup and fix slow tests * revert --------- Co-authored-by: Nicolas Patry <patry.nicolas@gmail.com> Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> Co-authored-by: younesbelkada <younesbelkada@gmail.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -63,6 +63,7 @@ from .fill_mask import FillMaskPipeline
|
||||
from .image_classification import ImageClassificationPipeline
|
||||
from .image_segmentation import ImageSegmentationPipeline
|
||||
from .image_to_text import ImageToTextPipeline
|
||||
from .mask_generation import MaskGenerationPipeline
|
||||
from .object_detection import ObjectDetectionPipeline
|
||||
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
|
||||
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
|
||||
@@ -124,6 +125,7 @@ if is_torch_available():
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForImageSegmentation,
|
||||
AutoModelForMaskedLM,
|
||||
AutoModelForMaskGeneration,
|
||||
AutoModelForObjectDetection,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSemanticSegmentation,
|
||||
@@ -384,6 +386,13 @@ SUPPORTED_TASKS = {
|
||||
"default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "4800870")}},
|
||||
"type": "video",
|
||||
},
|
||||
"mask-generation": {
|
||||
"impl": MaskGenerationPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForMaskGeneration,) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": ("facebook/sam-vit-huge", "997b15")}},
|
||||
"type": "multimodal",
|
||||
},
|
||||
}
|
||||
|
||||
NO_FEATURE_EXTRACTOR_TASKS = set()
|
||||
@@ -536,6 +545,7 @@ def pipeline(
|
||||
- `"image-classification"`: will return a [`ImageClassificationPipeline`].
|
||||
- `"image-segmentation"`: will return a [`ImageSegmentationPipeline`].
|
||||
- `"image-to-text"`: will return a [`ImageToTextPipeline`].
|
||||
- `"mask-generation"`: will return a [`MaskGenerationPipeline`].
|
||||
- `"object-detection"`: will return a [`ObjectDetectionPipeline`].
|
||||
- `"question-answering"`: will return a [`QuestionAnsweringPipeline`].
|
||||
- `"summarization"`: will return a [`SummarizationPipeline`].
|
||||
|
||||
@@ -97,6 +97,8 @@ def _pad(items, key, padding_value, padding_side):
|
||||
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
|
||||
elif dim == 3:
|
||||
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
|
||||
elif dim == 4:
|
||||
tensor = torch.zeros((batch_size, max_length, shape[-2], shape[-1]), dtype=dtype) + padding_value
|
||||
|
||||
for i, item in enumerate(items):
|
||||
if dim == 2:
|
||||
@@ -109,6 +111,12 @@ def _pad(items, key, padding_value, padding_side):
|
||||
tensor[i, -len(item[key][0]) :, :] = item[key][0].clone()
|
||||
else:
|
||||
tensor[i, : len(item[key][0]), :] = item[key][0].clone()
|
||||
elif dim == 4:
|
||||
if padding_side == "left":
|
||||
tensor[i, -len(item[key][0]) :, :, :] = item[key][0].clone()
|
||||
else:
|
||||
tensor[i, : len(item[key][0]), :, :] = item[key][0].clone()
|
||||
|
||||
return tensor
|
||||
else:
|
||||
return [item[key] for item in items]
|
||||
|
||||
@@ -81,11 +81,11 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
)
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocessor_kwargs = {}
|
||||
preprocess_kwargs = {}
|
||||
postprocess_kwargs = {}
|
||||
if "subtask" in kwargs:
|
||||
postprocess_kwargs["subtask"] = kwargs["subtask"]
|
||||
preprocessor_kwargs["subtask"] = kwargs["subtask"]
|
||||
preprocess_kwargs["subtask"] = kwargs["subtask"]
|
||||
if "threshold" in kwargs:
|
||||
postprocess_kwargs["threshold"] = kwargs["threshold"]
|
||||
if "mask_threshold" in kwargs:
|
||||
@@ -93,7 +93,7 @@ class ImageSegmentationPipeline(Pipeline):
|
||||
if "overlap_mask_area_threshold" in kwargs:
|
||||
postprocess_kwargs["overlap_mask_area_threshold"] = kwargs["overlap_mask_area_threshold"]
|
||||
|
||||
return preprocessor_kwargs, {}, postprocess_kwargs
|
||||
return preprocess_kwargs, {}, postprocess_kwargs
|
||||
|
||||
def __call__(self, images, **kwargs) -> Union[Predictions, List[Prediction]]:
|
||||
"""
|
||||
|
||||
286
src/transformers/pipelines/mask_generation.py
Normal file
286
src/transformers/pipelines/mask_generation.py
Normal file
@@ -0,0 +1,286 @@
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
from ..image_utils import load_image
|
||||
from ..utils import (
|
||||
add_end_docstrings,
|
||||
is_torch_available,
|
||||
logging,
|
||||
requires_backends,
|
||||
)
|
||||
from .base import PIPELINE_INIT_ARGS, ChunkPipeline
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_MASK_GENERATION_MAPPING
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||
class MaskGenerationPipeline(ChunkPipeline):
|
||||
"""
|
||||
Automatic mask generation for images using `SamForMaskGeneration`. This pipeline predicts binary masks for an
|
||||
image, given an image. It is a `ChunkPipeline` because you can separate the points in a mini-batch in order to
|
||||
avoid OOM issues. Use the `points_per_batch` argument to control the number of points that will be processed at the
|
||||
same time. Default is `64`.
|
||||
|
||||
The pipeline works in 3 steps:
|
||||
1. `preprocess`: A grid of 1024 points evenly separated is generated along with bounding boxes and point
|
||||
labels.
|
||||
For more details on how the points and bounding boxes are created, check the `_generate_crop_boxes`
|
||||
function. The image is also preprocessed using the `image_processor`. This function `yields` a minibatch of
|
||||
`points_per_batch`.
|
||||
|
||||
2. `forward`: feeds the outputs of `preprocess` to the model. The image embedding is computed only once.
|
||||
Calls both `self.model.get_image_embeddings` and makes sure that the gradients are not computed, and the
|
||||
tensors and models are on the same device.
|
||||
|
||||
3. `postprocess`: The most important part of the automatic mask generation happens here. Three steps
|
||||
are induced:
|
||||
- image_processor.post_process_masks (run on each minibatch loop): takes in the raw output masks,
|
||||
resizes them according
|
||||
to the image size, and transforms them to binary masks.
|
||||
- image_processor.filter_masks (on each minibatch loop): uses both `pred_iou_thresh` and
|
||||
`stability_scores`. Also
|
||||
applies a variety of filters based on non maximum suppression to remove bad masks.
|
||||
- image_processor.post_process_for_mask_generation applies NMS (non-maximum suppression) on the masks to only keep relevant ones.
|
||||
|
||||
Arguments:
|
||||
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
|
||||
The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
|
||||
[`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
|
||||
tokenizer ([`PreTrainedTokenizer`]):
|
||||
The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
|
||||
[`PreTrainedTokenizer`].
|
||||
feature_extractor ([`SequenceFeatureExtractor`]):
|
||||
The feature extractor that will be used by the pipeline to encode the input.
|
||||
points_per_batch (*optional*, int, default to 64):
|
||||
Sets the number of points run simultaneously by the model. Higher numbers may be faster but use more GPU
|
||||
memory.
|
||||
output_bboxes_mask (`bool`, *optional*, default to `False`):
|
||||
Whether or not to output the bounding box predictions.
|
||||
output_rle_mask (`bool`, *optional*, default to `False`):
|
||||
Whether or not to output the masks in `RLE` format
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> generator = pipeline(model="facebook/sam-vit-huge", task="mask-generation")
|
||||
>>> outputs = generator(
|
||||
... "http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
... )
|
||||
|
||||
>>> outputs = generator(
|
||||
... "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", points_per_batch=128
|
||||
... )
|
||||
```
|
||||
|
||||
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
|
||||
|
||||
This segmentation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
|
||||
`"mask-generation"`.
|
||||
|
||||
See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=mask-generation).
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
    """
    Build the pipeline and verify that it can run in the current environment.

    All keyword arguments are forwarded to the parent constructor. Afterwards the "vision" and "torch" backends
    are required (in that order), the framework must be PyTorch, and the loaded model type is checked against
    `MODEL_FOR_MASK_GENERATION_MAPPING`.

    Raises:
        ValueError: If the pipeline framework is anything other than `"pt"`.
    """
    super().__init__(**kwargs)

    # Backends are checked one at a time, vision first, then torch.
    for backend in ("vision", "torch"):
        requires_backends(self, backend)

    if self.framework != "pt":
        raise ValueError(f"The {self.__class__} is only available in PyTorch.")

    self.check_model_type(MODEL_FOR_MASK_GENERATION_MAPPING)
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
preprocess_kwargs = {}
|
||||
postprocess_kwargs = {}
|
||||
forward_params = {}
|
||||
# preprocess args
|
||||
if "points_per_batch" in kwargs:
|
||||
preprocess_kwargs["points_per_batch"] = kwargs["points_per_batch"]
|
||||
if "points_per_crop" in kwargs:
|
||||
preprocess_kwargs["points_per_crop"] = kwargs["points_per_crop"]
|
||||
if "crops_n_layers" in kwargs:
|
||||
preprocess_kwargs["crops_n_layers"] = kwargs["crops_n_layers"]
|
||||
if "crop_overlap_ratio" in kwargs:
|
||||
preprocess_kwargs["crop_overlap_ratio"] = kwargs["crop_overlap_ratio"]
|
||||
if "crop_n_points_downscale_factor" in kwargs:
|
||||
preprocess_kwargs["crop_n_points_downscale_factor"] = kwargs["crop_n_points_downscale_factor"]
|
||||
# postprocess args
|
||||
if "pred_iou_thresh" in kwargs:
|
||||
forward_params["pred_iou_thresh"] = kwargs["pred_iou_thresh"]
|
||||
if "stability_score_offset" in kwargs:
|
||||
forward_params["stability_score_offset"] = kwargs["stability_score_offset"]
|
||||
if "mask_threshold" in kwargs:
|
||||
forward_params["mask_threshold"] = kwargs["mask_threshold"]
|
||||
if "stability_score_thresh" in kwargs:
|
||||
forward_params["stability_score_thresh"] = kwargs["stability_score_thresh"]
|
||||
if "crops_nms_thresh" in kwargs:
|
||||
postprocess_kwargs["crops_nms_thresh"] = kwargs["crops_nms_thresh"]
|
||||
if "output_rle_mask" in kwargs:
|
||||
postprocess_kwargs["output_rle_mask"] = kwargs["output_rle_mask"]
|
||||
if "output_bboxes_mask" in kwargs:
|
||||
postprocess_kwargs["output_bboxes_mask"] = kwargs["output_bboxes_mask"]
|
||||
return preprocess_kwargs, forward_params, postprocess_kwargs
|
||||
|
||||
def __call__(self, image, *args, num_workers=None, batch_size=None, **kwargs):
    """
    Generates binary segmentation masks

    Args:
        image (`np.ndarray` or `bytes` or `str` or `dict`):
            Image or list of images.
        num_workers (*optional*):
            Forwarded unchanged to the parent pipeline's `__call__`.
        batch_size (*optional*):
            Forwarded unchanged to the parent pipeline's `__call__`.
        points_per_batch (`int`, *optional*, defaults to 64):
            Number of grid points placed in each chunk yielded by `preprocess`.
        points_per_crop (`int`, *optional*, defaults to 32):
            Forwarded to `image_processor.generate_crop_boxes`.
        mask_threshold (`float`, *optional*, defaults to 0.0):
            Threshold to use when turning the predicted masks into binary values.
        pred_iou_thresh (`float`, *optional*, defaults to 0.88):
            A filtering threshold in `[0,1]` applied on the model's predicted mask quality.
        stability_score_thresh (`float`, *optional*, defaults to 0.95):
            A filtering threshold in `[0,1]`, using the stability of the mask under changes to the cutoff used to
            binarize the model's mask predictions.
        stability_score_offset (`int`, *optional*, defaults to 1):
            The amount to shift the cutoff when calculating the stability score.
        crops_nms_thresh (`float`, *optional*, defaults to 0.7):
            The box IoU cutoff used by non-maximal suppression to filter duplicate masks.
        crops_n_layers (`int`, *optional*, defaults to 0):
            If `crops_n_layers>0`, mask prediction will be run again on crops of the image. Sets the number of
            layers to run, where each layer has 2**i_layer number of image crops.
        crop_overlap_ratio (`float`, *optional*, defaults to `512 / 1500`):
            Sets the degree to which crops overlap. In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this overlap.
        crop_n_points_downscale_factor (`int`, *optional*, defaults to `1`):
            The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
        output_rle_mask (`bool`, *optional*, defaults to `False`):
            Whether to add the `rle_mask` entry to the result.
        output_bboxes_mask (`bool`, *optional*, defaults to `False`):
            Whether to add the `bounding_boxes` entry to the result.

    Return:
        `Dict`: The dictionary assembled by `postprocess`, with the following keys:

        - **masks** -- The masks kept by `image_processor.post_process_for_mask_generation`.
        - **scores** -- The IoU scores returned alongside those masks.
        - **rle_mask** -- Only present when `output_rle_mask=True`.
        - **bounding_boxes** -- Only present when `output_bboxes_mask=True`.

        Any other key still present in the per-chunk model outputs is also included, as a list with one entry per
        chunk.

    """
    return super().__call__(image, *args, num_workers=num_workers, batch_size=batch_size, **kwargs)
|
||||
|
||||
def preprocess(
    self,
    image,
    points_per_batch=64,
    crops_n_layers: int = 0,
    crop_overlap_ratio: float = 512 / 1500,
    points_per_crop: Optional[int] = 32,
    crop_n_points_downscale_factor: Optional[int] = 1,
):
    """
    Load the image, compute its embeddings once, and yield the prompt points chunk by chunk.

    The image processor supplies the crop boxes, grid points, cropped images and point labels. The cropped images
    are run through the image processor and `self.model.get_image_embeddings`; the resulting embeddings are
    attached to every yielded chunk so the model never has to re-encode the image.

    Args:
        image: Anything accepted by `load_image`.
        points_per_batch (`int`, *optional*, defaults to 64):
            Number of grid points per yielded chunk. `None` puts all points in a single chunk.
        crops_n_layers (`int`, defaults to 0): Forwarded to `image_processor.generate_crop_boxes`.
        crop_overlap_ratio (`float`, defaults to `512 / 1500`): Forwarded to `generate_crop_boxes`.
        points_per_crop (`int`, *optional*, defaults to 32): Forwarded to `generate_crop_boxes`.
        crop_n_points_downscale_factor (`int`, *optional*, defaults to 1): Forwarded to `generate_crop_boxes`.

    Yields:
        `dict`: `input_points`, `input_labels`, `input_boxes`, `is_last`, plus every remaining entry of the
        image-processor output (including `image_embeddings`). `is_last` is `True` only for the final chunk.

    Raises:
        ValueError: If `points_per_batch` is zero or negative.
    """
    image = load_image(image)
    target_size = self.image_processor.size["longest_edge"]
    crop_boxes, grid_points, cropped_images, input_labels = self.image_processor.generate_crop_boxes(
        image, target_size, crops_n_layers, crop_overlap_ratio, points_per_crop, crop_n_points_downscale_factor
    )
    model_inputs = self.image_processor(images=cropped_images, return_tensors="pt")

    # The image embedding is computed a single time here and reused by every chunk.
    with self.device_placement():
        if self.framework == "pt":
            inference_context = self.get_inference_context()
            with inference_context():
                model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
                image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values"))
                model_inputs["image_embeddings"] = image_embeddings

    n_points = grid_points.shape[1]
    points_per_batch = points_per_batch if points_per_batch is not None else n_points

    if points_per_batch <= 0:
        raise ValueError(
            "Cannot have points_per_batch<=0. Must be >=1 to returned batched outputs. "
            "To return all points at once, set points_per_batch to None"
        )

    for i in range(0, n_points, points_per_batch):
        batched_points = grid_points[:, i : i + points_per_batch, :, :]
        labels = input_labels[:, i : i + points_per_batch]
        # The final chunk is the one whose slice reaches the end of the grid. Comparing against
        # `n_points - points_per_batch` would never match when the batch size does not divide the
        # number of points (or exceeds it), leaving no chunk flagged as last.
        is_last = i + points_per_batch >= n_points
        yield {
            "input_points": batched_points,
            "input_labels": labels,
            "input_boxes": crop_boxes,
            "is_last": is_last,
            **model_inputs,
        }
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
model_inputs,
|
||||
pred_iou_thresh=0.88,
|
||||
stability_score_thresh=0.95,
|
||||
mask_threshold=0,
|
||||
stability_score_offset=1,
|
||||
):
|
||||
input_boxes = model_inputs.pop("input_boxes")
|
||||
is_last = model_inputs.pop("is_last")
|
||||
original_sizes = model_inputs.pop("original_sizes").tolist()
|
||||
reshaped_input_sizes = model_inputs.pop("reshaped_input_sizes").tolist()
|
||||
|
||||
model_outputs = self.model(**model_inputs)
|
||||
|
||||
# post processing happens here in order to avoid CPU GPU copies of ALL the masks
|
||||
low_resolution_masks = model_outputs["pred_masks"]
|
||||
masks = self.image_processor.post_process_masks(
|
||||
low_resolution_masks, original_sizes, reshaped_input_sizes, mask_threshold, binarize=False
|
||||
)
|
||||
iou_scores = model_outputs["iou_scores"]
|
||||
masks, iou_scores, boxes = self.image_processor.filter_masks(
|
||||
masks[0],
|
||||
iou_scores[0],
|
||||
original_sizes[0],
|
||||
input_boxes[0],
|
||||
pred_iou_thresh,
|
||||
stability_score_thresh,
|
||||
mask_threshold,
|
||||
stability_score_offset,
|
||||
)
|
||||
return {
|
||||
"masks": masks,
|
||||
"is_last": is_last,
|
||||
"boxes": boxes,
|
||||
"iou_scores": iou_scores,
|
||||
}
|
||||
|
||||
def postprocess(
    self,
    model_outputs,
    output_rle_mask=False,
    output_bboxes_mask=False,
    crops_nms_thresh=0.7,
):
    """
    Merge the per-chunk outputs and run the final mask selection.

    Scores and boxes of all chunks are concatenated with `torch.cat`, masks are gathered into one flat list, and
    everything is handed to `image_processor.post_process_for_mask_generation` together with `crops_nms_thresh`.

    Args:
        model_outputs: Iterable of per-chunk dictionaries; `iou_scores`, `masks` and `boxes` are popped from each.
        output_rle_mask (`bool`, defaults to `False`): Add the `rle_mask` entry to the result.
        output_bboxes_mask (`bool`, defaults to `False`): Add the `bounding_boxes` entry to the result.
        crops_nms_thresh (`float`, defaults to 0.7): Passed through to the image processor.

    Returns:
        `dict`: `masks` and `scores`, the optional `rle_mask` / `bounding_boxes`, and every leftover per-chunk key
        collected as a list with one value per chunk.
    """
    score_chunks, flat_masks, box_chunks = [], [], []
    for chunk in model_outputs:
        score_chunks.append(chunk.pop("iou_scores"))
        # `masks` is itself a collection, so it is flattened rather than nested.
        flat_masks.extend(chunk.pop("masks"))
        box_chunks.append(chunk.pop("boxes"))

    merged_scores = torch.cat(score_chunks)
    merged_boxes = torch.cat(box_chunks)
    output_masks, iou_scores, rle_mask, bounding_boxes = self.image_processor.post_process_for_mask_generation(
        flat_masks, merged_scores, merged_boxes, crops_nms_thresh
    )

    # Whatever is still left in the chunk dictionaries is passed through, grouped by key.
    leftovers = defaultdict(list)
    for chunk in model_outputs:
        for key, value in chunk.items():
            leftovers[key].append(value)

    result = {"masks": output_masks, "scores": iou_scores}
    if output_rle_mask:
        result["rle_mask"] = rle_mask
    if output_bboxes_mask:
        result["bounding_boxes"] = bounding_boxes
    result.update(leftovers)
    return result
|
||||
Reference in New Issue
Block a user