From bd9d1ddf417f33e650da7e6a9e405573609dc673 Mon Sep 17 00:00:00 2001 From: Asif Ajrof Date: Fri, 31 May 2024 16:34:29 +0600 Subject: [PATCH] Update sam.md (#31130) `mask` variable is not defined. probably a writing mistake. it should be `segmentation_map`. `segmentation_map` should be a `1` channel image rather than `RGB`. [on a different note, the `mask_url` is the same as `raw_image`. could provide a better example. --- docs/source/en/model_doc/sam.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/sam.md b/docs/source/en/model_doc/sam.md index 2fc06193a7..12a87eb5bc 100644 --- a/docs/source/en/model_doc/sam.md +++ b/docs/source/en/model_doc/sam.md @@ -81,10 +81,10 @@ processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") mask_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" -segmentation_map = Image.open(requests.get(mask_url, stream=True).raw).convert("RGB") +segmentation_map = Image.open(requests.get(mask_url, stream=True).raw).convert("1") input_points = [[[450, 600]]] # 2D location of a window in the image -inputs = processor(raw_image, input_points=input_points, segmentation_maps=mask, return_tensors="pt").to(device) +inputs = processor(raw_image, input_points=input_points, segmentation_maps=segmentation_map, return_tensors="pt").to(device) with torch.no_grad(): outputs = model(**inputs)