Add segmentation map processing to SAM Image Processor (#27463)

* add segmentation map processing to sam image processor

* fixup

* add tests

* reshaped_input_size is shape before padding

* update tests for size/shape outputs

* fixup

* add code snippet to docs

* Update docs/source/en/model_doc/sam.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Add missing backticks

* add `segmentation_maps` as arg for SamProcessor.__call__()

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Rosie Wood
2024-01-08 16:40:36 +00:00
committed by GitHub
parent 2272ab57a9
commit 73c88012b7
4 changed files with 291 additions and 57 deletions

View File

@@ -66,6 +66,34 @@ masks = processor.image_processor.post_process_masks(
scores = outputs.iou_scores
```
You can also process your own masks alongside the input images in the processor to be passed to the model.
```python
import torch
from PIL import Image
import requests
from transformers import SamModel, SamProcessor
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
mask_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
segmentation_map = Image.open(requests.get(mask_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]] # 2D location of a window in the image
inputs = processor(raw_image, input_points=input_points, segmentation_maps=mask, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores
```
Resources:
- [Demo notebook](https://github.com/huggingface/notebooks/blob/main/examples/segment_anything.ipynb) for using the model.