From 65e940208c38ebf82b5d7a2441eec361d2c968b1 Mon Sep 17 00:00:00 2001 From: sushmanth reddy <73489688+sushmanthreddy@users.noreply.github.com> Date: Mon, 28 Apr 2025 22:37:09 +0530 Subject: [PATCH] Samhq model addition (#35147) * added the configuartion for sam_hq * added the modeelling for sam_hq * added the sam hq mask decoder with hq features * added the code for the samhq * added the code for the samhq * added the code for the samhq * Delete src/transformers/models/sam_hq/modelling_sam_hq.py * added the code for the samhq * added the code for the samhq * added the chnages for the modeelling * added the code for sam hq for image processing * added code for the sam hq model * added the required changes * added the changes * added the key mappings for the sam hq * adding the working code of samhq * added the required files * adding the pt object * added the push to hub account * added the args for the sam maks decoder * added the args for the sam hq vision config * aded the some more documentation * removed the unecessary spaces * all required chnages * removed the image processor * added the required file * added the changes for the checkcopies * added the code for modular file * added the changes for the __init file * added the code for the interm embeds * added the code for sam hq * added the changes for modular file * added the test file * added the changes required * added the changes required * added the code for the * added the cl errors * added the changes * added the required changes * added the some code * added the code for the removing image processor * added the test dimensins * added the code for the removing extra used variables * added the code for modeluar file hf_mlp for a better name * removed abbrevaation in core functionality * removed abbrevaation in core functionality * .contiguous() method is often used to ensure that the tensor is stored in a contiguous block of memory * added the code which is after make fixup * added some test for the intermediate embeddings test * added the code for the torch support in sam hq * added the code for the updated modular file * added the changes for documentations as mentioned * removed the heading * add the changes for the code * first mentioned issue resolved * added the changes code to processor * added the easy loading to init file * added the changes to code * added the code to changes * added the code to work * added the code for sam hq * added the code for sam hq * added the code for the point pad value * added the small test for the image embeddings and intermediate embedding * added the code * added the code * added the code for the tests * added the code * added ythe code for the processor file * added the code * added the code * added the code * added the code * added the code * added the code for tests and some checks * added some code * added the code * added the code * added some code * added some code * added the changes for required * added the code * added the code * added the code * added the code * added the code * added the code * added the code * added the code * added the code * added the code * added some changes * added some changes * removed spaces and quality checks * added some code * added some code * added some code * added code quality checks * added the checks for quality checks * addded some code which fixes test_inference_mask_generation_no_point * added code for the test_inference_mask_generation_one_point_one_bb * added code for the test_inference_mask_generation_one_point_one_bb_zero * added code for the test_inference_mask_generation_one_box * added some code in modelling for testing * added some code which sort maks with high score * added some code * added some code * added some code for the move KEYS_TO_MODIFY_MAPPING * added some code for the unsqueeze removal * added some code for the unsqueeze removal * added some code * added some code * add some code * added some code * added some code * added some testign values changed * added changes to code in sam hq for readbility purpose * added pre commit checks * added the fix samvisionmodel for compatibilty * added the changes made on sam by cyyever * fixed the tests for samhq * added some the code * added some code related to init file issue during merge conflicts * remobved the merge conflicts * added changes mentioned by aruther and mobap * added changes mentioned by aruther and mobap * solving quality checks * added the changes for input clearly * added the changes * added changes in mask generation file rgearding model inputs and sam hq quargs in processor file * added changes in processor file * added the Setup -> setupclass conversion * added the code mentioned for processor * added changes for the code * added some code * added some code * added some code --------- Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/sam_hq.md | 127 ++ src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 5 + .../models/auto/image_processing_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 8 + .../models/auto/processing_auto.py | 1 + src/transformers/models/sam_hq/__init__.py | 28 + .../models/sam_hq/configuration_sam_hq.py | 315 +++ .../models/sam_hq/convert_samhq_to_hf.py | 277 +++ .../models/sam_hq/modeling_sam_hq.py | 1793 +++++++++++++++++ .../models/sam_hq/modular_sam_hq.py | 737 +++++++ .../models/sam_hq/processing_samhq.py | 330 +++ src/transformers/pipelines/mask_generation.py | 12 +- tests/models/sam_hq/__init__.py | 0 tests/models/sam_hq/test_modeling_sam_hq.py | 1116 ++++++++++ tests/models/sam_hq/test_processor_samhq.py | 167 ++ utils/check_config_attributes.py | 2 + utils/check_copies.py | 1 + utils/check_docstrings.py | 2 + utils/check_repo.py | 1 + utils/not_doctested.txt | 1 + 22 files changed, 4926 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/model_doc/sam_hq.md create mode 100644 src/transformers/models/sam_hq/__init__.py create mode 100644 src/transformers/models/sam_hq/configuration_sam_hq.py create mode 100644 src/transformers/models/sam_hq/convert_samhq_to_hf.py create mode 100644 src/transformers/models/sam_hq/modeling_sam_hq.py create mode 100644 src/transformers/models/sam_hq/modular_sam_hq.py create mode 100644 src/transformers/models/sam_hq/processing_samhq.py create mode 100644 tests/models/sam_hq/__init__.py create mode 100644 tests/models/sam_hq/test_modeling_sam_hq.py create mode 100644 tests/models/sam_hq/test_processor_samhq.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 6a9a50c1f7..171ca01f65 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1017,6 +1017,8 @@ title: Qwen2VL - local: model_doc/sam title: Segment Anything + - local: model_doc/sam_hq + title: Segment Anything High Quality - local: model_doc/shieldgemma2 title: ShieldGemma2 - local: model_doc/siglip diff --git a/docs/source/en/model_doc/sam_hq.md b/docs/source/en/model_doc/sam_hq.md new file mode 100644 index 0000000000..8c60b86117 --- /dev/null +++ b/docs/source/en/model_doc/sam_hq.md @@ -0,0 +1,127 @@ +# SAM-HQ + +## Overview + +SAM-HQ (High-Quality Segment Anything Model) was proposed in [Segment Anything in High Quality](https://arxiv.org/pdf/2306.01567.pdf) by Lei Ke, Mingqiao Ye, Martin Danelljan, Yifan Liu, Yu-Wing Tai, Chi-Keung Tang, Fisher Yu. + +The model is an enhancement to the original SAM model that produces significantly higher quality segmentation masks while maintaining SAM's original promptable design, efficiency, and zero-shot generalizability. + +![example image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-output.png) + + +SAM-HQ introduces several key improvements over the original SAM model: + +1. High-Quality Output Token: A learnable token injected into SAM's mask decoder for higher quality mask prediction +2. Global-local Feature Fusion: Combines features from different stages of the model for improved mask details +3. Training Data: Uses a carefully curated dataset of 44K high-quality masks instead of SA-1B +4. Efficiency: Adds only 0.5% additional parameters while significantly improving mask quality +5. Zero-shot Capability: Maintains SAM's strong zero-shot performance while improving accuracy + +The abstract from the paper is the following: + +*The recent Segment Anything Model (SAM) represents a big leap in scaling up segmentation models, allowing for powerful zero-shot capabilities and flexible prompting. Despite being trained with 1.1 billion masks, SAM's mask prediction quality falls short in many cases, particularly when dealing with objects that have intricate structures. We propose HQ-SAM, equipping SAM with the ability to accurately segment any object, while maintaining SAM's original promptable design, efficiency, and zero-shot generalizability. Our careful design reuses and preserves the pre-trained model weights of SAM, while only introducing minimal additional parameters and computation. We design a learnable High-Quality Output Token, which is injected into SAM's mask decoder and is responsible for predicting the high-quality mask. Instead of only applying it on mask-decoder features, we first fuse them with early and final ViT features for improved mask details. To train our introduced learnable parameters, we compose a dataset of 44K fine-grained masks from several sources. HQ-SAM is only trained on the introduced dataset of 44k masks, which takes only 4 hours on 8 GPUs.* + +Tips: + +- SAM-HQ produces higher quality masks than the original SAM model, particularly for objects with intricate structures and fine details +- The model predicts binary masks with more accurate boundaries and better handling of thin structures +- Like SAM, the model performs better with input 2D points and/or input bounding boxes +- You can prompt multiple points for the same image and predict a single high-quality mask +- The model maintains SAM's zero-shot generalization capabilities +- SAM-HQ only adds ~0.5% additional parameters compared to SAM +- Fine-tuning the model is not supported yet + +This model was contributed by [sushmanth](https://huggingface.co/sushmanth). +The original code can be found [here](https://github.com/SysCV/SAM-HQ). + +Below is an example on how to run mask generation given an image and a 2D point: + +```python +import torch +from PIL import Image +import requests +from transformers import SamHQModel, SamHQProcessor + +device = "cuda" if torch.cuda.is_available() else "cpu" +model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device) +processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + +img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" +raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") +input_points = [[[450, 600]]] # 2D location of a window in the image + +inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device) +with torch.no_grad(): + outputs = model(**inputs) + +masks = processor.image_processor.post_process_masks( + outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu() +) +scores = outputs.iou_scores +``` + +You can also process your own masks alongside the input images in the processor to be passed to the model: + +```python +import torch +from PIL import Image +import requests +from transformers import SamHQModel, SamHQProcessor + +device = "cuda" if torch.cuda.is_available() else "cpu" +model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device) +processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + +img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" +raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") +mask_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" +segmentation_map = Image.open(requests.get(mask_url, stream=True).raw).convert("1") +input_points = [[[450, 600]]] # 2D location of a window in the image + +inputs = processor(raw_image, input_points=input_points, segmentation_maps=segmentation_map, return_tensors="pt").to(device) +with torch.no_grad(): + outputs = model(**inputs) + +masks = processor.image_processor.post_process_masks( + outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu() +) +scores = outputs.iou_scores +``` + + +## Resources + +A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with SAM-HQ: + +- Demo notebook for using the model (coming soon) +- Paper implementation and code: [SAM-HQ GitHub Repository](https://github.com/SysCV/SAM-HQ) + +## SamHQConfig + +[[autodoc]] SamHQConfig + +## SamHQVisionConfig + +[[autodoc]] SamHQVisionConfig + +## SamHQMaskDecoderConfig + +[[autodoc]] SamHQMaskDecoderConfig + +## SamHQPromptEncoderConfig + +[[autodoc]] SamHQPromptEncoderConfig + +## SamHQProcessor + +[[autodoc]] SamHQProcessor + +## SamHQVisionModel + +[[autodoc]] SamHQVisionModel + + +## SamHQModel + +[[autodoc]] SamHQModel + - forward \ No newline at end of file diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 8b439b47c0..5feb76f1a1 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -254,6 +254,7 @@ if TYPE_CHECKING: from .rt_detr_v2 import * from .rwkv import * from .sam import * + from .sam_hq import * from .seamless_m4t import * from .seamless_m4t_v2 import * from .segformer import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 9da4e5c460..bbef0bc920 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -286,6 +286,8 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("rt_detr_v2", "RTDetrV2Config"), ("rwkv", "RwkvConfig"), ("sam", "SamConfig"), + ("sam_hq", "SamHQConfig"), + ("sam_hq_vision_model", "SamHQVisionConfig"), ("sam_vision_model", "SamVisionConfig"), ("seamless_m4t", "SeamlessM4TConfig"), ("seamless_m4t_v2", "SeamlessM4Tv2Config"), @@ -658,6 +660,8 @@ MODEL_NAMES_MAPPING = OrderedDict( ("rt_detr_v2", "RT-DETRv2"), ("rwkv", "RWKV"), ("sam", "SAM"), + ("sam_hq", "SAM-HQ"), + ("sam_hq_vision_model", "SamHQVisionModel"), ("sam_vision_model", "SamVisionModel"), ("seamless_m4t", "SeamlessM4T"), ("seamless_m4t_v2", "SeamlessM4Tv2"), @@ -807,6 +811,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict( ("qwen2_5_vl_text", "qwen2_5_vl"), ("qwen2_vl_text", "qwen2_vl"), ("sam_vision_model", "sam"), + ("sam_hq_vision_model", "sam_hq"), ("llama4_text", "llama4"), ("blip_2_qformer", "blip_2"), ] diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 79b13f1f4a..ee941eed35 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -141,6 +141,7 @@ else: ("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")), ("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")), ("sam", ("SamImageProcessor",)), + ("sam_hq", ("SamImageProcessor",)), ("segformer", ("SegformerImageProcessor",)), ("seggpt", ("SegGptImageProcessor",)), ("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 889d3658ed..fe83a8d1b9 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -257,6 +257,8 @@ MODEL_MAPPING_NAMES = OrderedDict( ("rt_detr_v2", "RTDetrV2Model"), ("rwkv", "RwkvModel"), ("sam", "SamModel"), + ("sam_hq", "SamHQModel"), + ("sam_hq_vision_model", "SamHQVisionModel"), ("sam_vision_model", "SamVisionModel"), ("seamless_m4t", "SeamlessM4TModel"), ("seamless_m4t_v2", "SeamlessM4Tv2Model"), @@ -1495,6 +1497,12 @@ MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( ] ) +MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( + [ + ("sam_hq", "SamHQModel"), + ] +) + MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict( [ diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 7060b6125d..d49301cd5f 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -104,6 +104,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( ("qwen2_audio", "Qwen2AudioProcessor"), ("qwen2_vl", "Qwen2VLProcessor"), ("sam", "SamProcessor"), + ("sam_hq", "SamHQProcessor"), ("seamless_m4t", "SeamlessM4TProcessor"), ("sew", "Wav2Vec2Processor"), ("sew-d", "Wav2Vec2Processor"), diff --git a/src/transformers/models/sam_hq/__init__.py b/src/transformers/models/sam_hq/__init__.py new file mode 100644 index 0000000000..8074c56727 --- /dev/null +++ b/src/transformers/models/sam_hq/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_sam_hq import * + from .modeling_sam_hq import * + from .processing_samhq import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/sam_hq/configuration_sam_hq.py b/src/transformers/models/sam_hq/configuration_sam_hq.py new file mode 100644 index 0000000000..49062efc68 --- /dev/null +++ b/src/transformers/models/sam_hq/configuration_sam_hq.py @@ -0,0 +1,315 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/sam_hq/modular_sam_hq.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_sam_hq.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig + + +class SamHQPromptEncoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SamHQPromptEncoderModel`].The [`SamHQPromptEncoderModel`] + module is used to encode the input 2D points and bounding boxes. Instantiating a configuration defaults will yield a + similar configuration to that of the SAM_HQ model. The configuration is used to store the configuration of the model. + [Uminosachi/sam-hq](https://huggingface.co/Uminosachi/sam-hq) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model's output.Read the documentation from + [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + image_size (`int`, *optional*, defaults to 1024): + The expected output resolution of the image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + mask_input_channels (`int`, *optional*, defaults to 16): + The number of channels to be fed to the `MaskDecoder` module. + num_point_embeddings (`int`, *optional*, defaults to 4): + The number of point embeddings to be used. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the encoder and pooler. + """ + + base_config_key = "prompt_encoder_config" + + def __init__( + self, + hidden_size=256, + image_size=1024, + patch_size=16, + mask_input_channels=16, + num_point_embeddings=4, + hidden_act="gelu", + layer_norm_eps=1e-6, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.image_size = image_size + self.patch_size = patch_size + self.image_embedding_size = image_size // patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + + +class SamHQVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SamHQVisionModel`]. It is used to instantiate a SAM_HQ + vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration + defaults will yield a similar configuration to that of the SAM_HQ ViT-h + [facebook/sam_hq-vit-huge](https://huggingface.co/facebook/sam_hq-vit-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + output_channels (`int`, *optional*, defaults to 256): + Dimensionality of the output channels in the Patch Encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input image. + image_size (`int`, *optional*, defaults to 1024): + Expected resolution. Target size of the resized input image. + patch_size (`int`, *optional*, defaults to 16): + Size of the patches to be extracted from the input image. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 1e-10): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to query, key, value projections. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of mlp hidden dim to embedding dim. + use_abs_pos (`bool`, *optional*, defaults to `True`): + Whether to use absolute position embedding. + use_rel_pos (`bool`, *optional*, defaults to `True`): + Whether to use relative position embedding. + window_size (`int`, *optional*, defaults to 14): + Window size for relative position. + global_attn_indexes (`List[int]`, *optional*, defaults to `[2, 5, 8, 11]`): + The indexes of the global attention layers. + num_pos_feats (`int`, *optional*, defaults to 128): + The dimensionality of the position embedding. + mlp_dim (`int`, *optional*): + The dimensionality of the MLP layer in the Transformer encoder. If `None`, defaults to `mlp_ratio * + hidden_size`. + + Example: + + ```python + >>> from transformers import ( + ... SamHQVisionConfig, + ... SamHQVisionModel, + ... ) + + >>> # Initializing a SamHQVisionConfig with `"facebook/sam_hq-vit-huge"` style configuration + >>> configuration = SamHQVisionConfig() + + >>> # Initializing a SamHQVisionModel (with random weights) from the `"facebook/sam_hq-vit-huge"` style configuration + >>> model = SamHQVisionModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + base_config_key = "vision_config" + model_type = "sam_hq_vision_model" + + def __init__( + self, + hidden_size=768, + output_channels=256, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=1024, + patch_size=16, + hidden_act="gelu", + layer_norm_eps=1e-06, + attention_dropout=0.0, + initializer_range=1e-10, + qkv_bias=True, + mlp_ratio=4.0, + use_abs_pos=True, + use_rel_pos=True, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + num_pos_feats=128, + mlp_dim=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.output_channels = output_channels + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.num_pos_feats = num_pos_feats + self.mlp_dim = int(hidden_size * mlp_ratio) if mlp_dim is None else mlp_dim + + +class SamHQMaskDecoderConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`SamHQMaskDecoder`]. It is used to instantiate a SAM_HQ + mask decoder to the specified arguments, defining the model architecture. Instantiating a configuration defaults + will yield a similar configuration to that of the SAM_HQ-vit-h + [facebook/sam_hq-vit-huge](https://huggingface.co/facebook/sam_hq-vit-huge) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + hidden_act (`str`, *optional*, defaults to `"relu"`): + The non-linear activation function used inside the `SamHQMaskDecoder` module. + mlp_dim (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 2): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + attention_downsample_rate (`int`, *optional*, defaults to 2): + The downsampling rate of the attention layer. + num_multimask_outputs (`int`, *optional*, defaults to 3): + The number of outputs from the `SamHQMaskDecoder` module. In the Segment Anything paper, this is set to 3. + iou_head_depth (`int`, *optional*, defaults to 3): + The number of layers in the IoU head module. + iou_head_hidden_dim (`int`, *optional*, defaults to 256): + The dimensionality of the hidden states in the IoU head module. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + + + vit_dim (`int`, *optional*, defaults to 768): + Dimensionality of the Vision Transformer (ViT) used in the `SamHQMaskDecoder` module. + """ + + base_config_key = "mask_decoder_config" + + def __init__( + self, + hidden_size=256, + hidden_act="relu", + mlp_dim=2048, + num_hidden_layers=2, + num_attention_heads=8, + attention_downsample_rate=2, + num_multimask_outputs=3, + iou_head_depth=3, + iou_head_hidden_dim=256, + layer_norm_eps=1e-6, + vit_dim=768, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_dim = mlp_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.attention_downsample_rate = attention_downsample_rate + self.num_multimask_outputs = num_multimask_outputs + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + self.layer_norm_eps = layer_norm_eps + self.vit_dim = vit_dim + + +class SamHQConfig(PretrainedConfig): + r""" + [`SamHQConfig`] is the configuration class to store the configuration of a [`SamHQModel`]. It is used to instantiate a + SAM-HQ model according to the specified arguments, defining the vision model, prompt-encoder model and mask decoder + configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the + SAM-HQ-ViT-H [sushmanth/sam_hq_vit_h](https://huggingface.co/sushmanth/sam_hq_vit_h) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (Union[`dict`, `SamHQVisionConfig`], *optional*): + Dictionary of configuration options used to initialize [`SamHQVisionConfig`]. + prompt_encoder_config (Union[`dict`, `SamHQPromptEncoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`SamHQPromptEncoderConfig`]. + mask_decoder_config (Union[`dict`, `SamHQMaskDecoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`SamHQMaskDecoderConfig`]. + kwargs (*optional*): + Dictionary of keyword arguments. + """ + + model_type = "sam_hq" + sub_configs = { + "prompt_encoder_config": SamHQPromptEncoderConfig, + "mask_decoder_config": SamHQMaskDecoderConfig, + "vision_config": SamHQVisionConfig, + } + + def __init__( + self, + vision_config=None, + prompt_encoder_config=None, + mask_decoder_config=None, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + vision_config = vision_config if vision_config is not None else {} + prompt_encoder_config = prompt_encoder_config if prompt_encoder_config is not None else {} + mask_decoder_config = mask_decoder_config if mask_decoder_config is not None else {} + + if isinstance(vision_config, SamHQVisionConfig): + vision_config = vision_config.to_dict() + if isinstance(prompt_encoder_config, SamHQPromptEncoderConfig): + prompt_encoder_config = prompt_encoder_config.to_dict() + if isinstance(mask_decoder_config, SamHQMaskDecoderConfig): + mask_decoder_config = mask_decoder_config.to_dict() + + self.vision_config = SamHQVisionConfig(**vision_config) + self.prompt_encoder_config = SamHQPromptEncoderConfig(**prompt_encoder_config) + self.mask_decoder_config = SamHQMaskDecoderConfig(**mask_decoder_config) + self.initializer_range = initializer_range + + +__all__ = ["SamHQVisionConfig", "SamHQMaskDecoderConfig", "SamHQPromptEncoderConfig", "SamHQConfig"] diff --git a/src/transformers/models/sam_hq/convert_samhq_to_hf.py b/src/transformers/models/sam_hq/convert_samhq_to_hf.py new file mode 100644 index 0000000000..366b84abfc --- /dev/null +++ b/src/transformers/models/sam_hq/convert_samhq_to_hf.py @@ -0,0 +1,277 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Convert SAM-HQ checkpoints from the original repository. + +URL: https://github.com/SysCV/sam-hq + +""" + +import argparse + +import numpy as np +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import SamHQConfig, SamHQModel, SamHQProcessor, SamHQVisionConfig, SamImageProcessor + + +def get_config(model_name): + if "sam_hq_vit_b" in model_name: + vision_config = SamHQVisionConfig() + vit_dim = 768 # Base model dimension + elif "sam_hq_vit_l" in model_name: + vision_config = SamHQVisionConfig( + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + global_attn_indexes=[5, 11, 17, 23], + ) + vit_dim = 1024 # Large model dimension + elif "sam_hq_vit_h" in model_name: + vision_config = SamHQVisionConfig( + hidden_size=1280, + num_hidden_layers=32, + num_attention_heads=16, + global_attn_indexes=[7, 15, 23, 31], + ) + vit_dim = 1280 # Huge model dimension + + # Create mask decoder config with appropriate vit_dim + mask_decoder_config = {"vit_dim": vit_dim} + + config = SamHQConfig( + vision_config=vision_config, + mask_decoder_config=mask_decoder_config, + ) + + return config + + +KEYS_TO_MODIFY_MAPPING = { + "iou_prediction_head.layers.0": "iou_prediction_head.proj_in", + "iou_prediction_head.layers.1": "iou_prediction_head.layers.0", + "iou_prediction_head.layers.2": "iou_prediction_head.proj_out", + "mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1", + "mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm", + "mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2", + "mask_downscaling.0": "mask_embed.conv1", + "mask_downscaling.1": "mask_embed.layer_norm1", + "mask_downscaling.3": "mask_embed.conv2", + "mask_downscaling.4": "mask_embed.layer_norm2", + "mask_downscaling.6": "mask_embed.conv3", + "point_embeddings": "point_embed", + "pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding", + "image_encoder": "vision_encoder", + "neck.0": "neck.conv1", + "neck.1": "neck.layer_norm1", + "neck.2": "neck.conv2", + "neck.3": "neck.layer_norm2", + "patch_embed.proj": "patch_embed.projection", + ".norm": ".layer_norm", + "blocks": "layers", + # HQ-specific mappings + "mask_decoder.hf_token": "mask_decoder.hq_token", + "mask_decoder.compress_vit_feat.0": "mask_decoder.compress_vit_conv1", + "mask_decoder.compress_vit_feat.1": "mask_decoder.compress_vit_norm", + "mask_decoder.compress_vit_feat.3": "mask_decoder.compress_vit_conv2", + "mask_decoder.embedding_encoder.0": "mask_decoder.encoder_conv1", + "mask_decoder.embedding_encoder.1": "mask_decoder.encoder_norm", + "mask_decoder.embedding_encoder.3": "mask_decoder.encoder_conv2", + "mask_decoder.embedding_maskfeature.0": "mask_decoder.mask_conv1", + "mask_decoder.embedding_maskfeature.1": "mask_decoder.mask_norm", + "mask_decoder.embedding_maskfeature.3": "mask_decoder.mask_conv2", + "mask_decoder.hf_mlp": "mask_decoder.hq_mask_mlp", + # Add patterns for the output_hypernetworks_mlps and hq_mask_mlp + "output_hypernetworks_mlps.0.layers.0": "output_hypernetworks_mlps.0.proj_in", + "output_hypernetworks_mlps.0.layers.1": "output_hypernetworks_mlps.0.layers.0", + "output_hypernetworks_mlps.0.layers.2": "output_hypernetworks_mlps.0.proj_out", + "output_hypernetworks_mlps.1.layers.0": "output_hypernetworks_mlps.1.proj_in", + "output_hypernetworks_mlps.1.layers.1": "output_hypernetworks_mlps.1.layers.0", + "output_hypernetworks_mlps.1.layers.2": "output_hypernetworks_mlps.1.proj_out", + "output_hypernetworks_mlps.2.layers.0": "output_hypernetworks_mlps.2.proj_in", + "output_hypernetworks_mlps.2.layers.1": "output_hypernetworks_mlps.2.layers.0", + "output_hypernetworks_mlps.2.layers.2": "output_hypernetworks_mlps.2.proj_out", + "output_hypernetworks_mlps.3.layers.0": "output_hypernetworks_mlps.3.proj_in", + "output_hypernetworks_mlps.3.layers.1": "output_hypernetworks_mlps.3.layers.0", + "output_hypernetworks_mlps.3.layers.2": "output_hypernetworks_mlps.3.proj_out", + "hq_mask_mlp.layers.0": "hq_mask_mlp.proj_in", + "hq_mask_mlp.layers.1": "hq_mask_mlp.layers.0", + "hq_mask_mlp.layers.2": "hq_mask_mlp.proj_out", +} + + +def replace_keys(state_dict): + model_state_dict = {} + state_dict.pop("pixel_mean", None) + state_dict.pop("pixel_std", None) + + # Process each key in the state dict + for key, value in state_dict.items(): + new_key = key + + # Apply static mappings from KEYS_TO_MODIFY_MAPPING + for key_to_modify, replacement in KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in new_key: + new_key = new_key.replace(key_to_modify, replacement) + + model_state_dict[new_key] = value + + # Add mapping for shared embedding for positional embedding + if "prompt_encoder.shared_embedding.positional_embedding" in model_state_dict: + model_state_dict["shared_image_embedding.positional_embedding"] = model_state_dict[ + "prompt_encoder.shared_embedding.positional_embedding" + ] + + # Special handling for IOU prediction head keys + # Check if we're missing the expected keys and have the converted ones instead + if ( + "mask_decoder.iou_prediction_head.layers.0.weight" not in model_state_dict + and "mask_decoder.iou_prediction_head.proj_in.weight" in model_state_dict + ): + # Copy the converted key back to the expected format + model_state_dict["mask_decoder.iou_prediction_head.layers.0.weight"] = model_state_dict[ + "mask_decoder.iou_prediction_head.proj_in.weight" + ] + model_state_dict["mask_decoder.iou_prediction_head.layers.0.bias"] = model_state_dict[ + "mask_decoder.iou_prediction_head.proj_in.bias" + ] + + return model_state_dict + + +def convert_sam_hq_checkpoint(model_name, checkpoint_path, pytorch_dump_folder, push_to_hub, hub_path): + config = get_config(model_name) + + state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + state_dict = replace_keys(state_dict) + + image_processor = SamImageProcessor() + processor = SamHQProcessor(image_processor=image_processor) + hf_model = SamHQModel(config) + hf_model.eval() + + device = "cuda" if torch.cuda.is_available() else "cpu" + + hf_model.load_state_dict(state_dict) + + hf_model = hf_model.to(device) + + # Test the model with a sample image + img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + + input_points = [[[500, 375]]] + input_labels = [[1]] + + # Basic test without prompts + inputs = processor(images=np.array(raw_image), return_tensors="pt").to(device) + + with torch.no_grad(): + hf_model(**inputs) + + if model_name == "sam_hq_vit_b": + inputs = processor( + images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(device) + + with torch.no_grad(): + hf_model(**inputs) + + elif model_name == "sam_hq_vit_h": + inputs = processor( + images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(device) + + with torch.no_grad(): + hf_model(**inputs) + + input_boxes = [[[75.0, 275.0, 1725.0, 850.0]]] + + inputs = processor(images=np.array(raw_image), input_boxes=input_boxes, return_tensors="pt").to(device) + + with torch.no_grad(): + hf_model(**inputs) + + input_points = [[[400, 650], [800, 650]]] + input_labels = [[1, 1]] + + inputs = processor( + images=np.array(raw_image), input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(device) + + with torch.no_grad(): + hf_model(**inputs) + + if pytorch_dump_folder is not None: + processor.save_pretrained(pytorch_dump_folder) + hf_model.save_pretrained(pytorch_dump_folder) + + if push_to_hub: + repo_id = f"{hub_path}/{model_name}" + processor.push_to_hub(repo_id) + hf_model.push_to_hub(repo_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + choices = ["sam_hq_vit_b", "sam_hq_vit_h", "sam_hq_vit_l"] + parser.add_argument( + "--model_name", + choices=choices, + type=str, + required=True, + help="Name of the SAM-HQ model to convert", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + required=False, + help="Path to the SAM-HQ checkpoint (.pth file)", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + type=str, + default=None, + help="Path to save the converted model", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether to push the converted model to the hub", + ) + parser.add_argument( + "--hub_path", + type=str, + default="sushmanth", + help="Hugging Face Hub path where the model will be uploaded", + ) + + args = parser.parse_args() + + checkpoint_path = args.checkpoint_path + if checkpoint_path is None: + checkpoint_path = hf_hub_download("lkeab/hq-sam", f"{args.model_name}.pth") + + convert_sam_hq_checkpoint( + args.model_name, + checkpoint_path, + args.pytorch_dump_folder_path, + args.push_to_hub, + args.hub_path, + ) diff --git a/src/transformers/models/sam_hq/modeling_sam_hq.py b/src/transformers/models/sam_hq/modeling_sam_hq.py new file mode 100644 index 0000000000..21d9a60f2d --- /dev/null +++ b/src/transformers/models/sam_hq/modeling_sam_hq.py @@ -0,0 +1,1793 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/sam_hq/modular_sam_hq.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_sam_hq.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + logging, + replace_return_docstrings, +) +from .configuration_sam_hq import SamHQConfig, SamHQMaskDecoderConfig, SamHQPromptEncoderConfig, SamHQVisionConfig + + +logger = logging.get_logger(__name__) + + +@dataclass +class SamHQVisionEncoderOutput(ModelOutput): + """ + Base class for sam_hq vision model's outputs that also contains image embeddings obtained by applying the projection + layer to the pooler_output. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + intermediate_embeddings (`list(torch.FloatTensor)`, *optional*): + A list of intermediate embeddings collected from certain blocks within the model, typically those without + windowed attention. Each element in the list is of shape `(batch_size, sequence_length, hidden_size)`. + This is specific to SAM-HQ and not present in base SAM. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + intermediate_embeddings: Optional[List[torch.FloatTensor]] = None + + +@dataclass +class SamHQImageSegmentationOutput(ModelOutput): + """ + Base class for Segment-Anything model's output + + Args: + iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`): + The iou scores of the predicted masks. + pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`): + The predicted low resolutions masks. Needs to be post-processed by the processor + vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs. + vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + iou_scores: Optional[torch.FloatTensor] = None + pred_masks: Optional[torch.FloatTensor] = None + vision_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + vision_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + mask_decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +class SamHQPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).permute(0, 2, 3, 1) + return embeddings + + +class SamHQMLPBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim) + self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.lin1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.lin2(hidden_states) + return hidden_states + + +class SamHQVisionAttention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__(self, config, window_size): + super().__init__() + input_size = ( + (config.image_size // config.patch_size, config.image_size // config.patch_size) + if window_size == 0 + else (window_size, window_size) + ) + + self.num_attention_heads = config.num_attention_heads + head_dim = config.hidden_size // config.num_attention_heads + self.scale = head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias) + self.proj = nn.Linear(config.hidden_size, config.hidden_size) + + self.use_rel_pos = config.use_rel_pos + if self.use_rel_pos: + if input_size is None: + raise ValueError("Input size must be provided if using relative positional encoding.") + + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + + Args: + q_size (int): + size of the query. + k_size (int): + size of key k. + rel_pos (`torch.Tensor`): + relative position embeddings (L, channel). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + def get_decomposed_rel_pos( + self, + query: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], + ) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py + + Args: + query (`torch.Tensor`): + query q in the attention layer with shape (batch_size, query_height * query_width, channel). + rel_pos_h (`torch.Tensor`): + relative position embeddings (Lh, channel) for height axis. + rel_pos_w (`torch.Tensor`): + relative position embeddings (Lw, channel) for width axis. + q_size (tuple): + spatial sequence size of query q with (query_height, query_width). + k_size (tuple): + spatial sequence size of key k with (key_height, key_width). + + Returns: + decomposed_rel_pos (`torch.Tensor`): + decomposed relative position embeddings. + """ + query_height, query_width = q_size + key_height, key_width = k_size + relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h) + relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w) + + batch_size, _, dim = query.shape + reshaped_query = query.reshape(batch_size, query_height, query_width, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height) + rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width) + + decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + + return decomposed_rel_pos + + def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, batch_size, nHead, height * width, channel) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (batch_size * nHead, height * width, channel) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_weights = (query * self.scale) @ key.transpose(-2, -1) + + if self.use_rel_pos: + decomposed_rel_pos = self.get_decomposed_rel_pos( + query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights) + attn_weights = attn_weights + decomposed_rel_pos + + attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype) + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1) + attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1) + + attn_output = self.proj(attn_output) + + if output_attentions: + outputs = (attn_output, attn_weights) + else: + outputs = (attn_output, None) + + return outputs + + +class SamHQVisionSdpaAttention(SamHQVisionAttention): + """ + Multi-head Attention block with relative position embeddings. + Using SDPA instead of the default attention. + """ + + def __init__(self, config, window_size): + super().__init__(config, window_size) + + def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor: + if output_attentions: + logger.warning_once( + "`SamHQVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True`. Falling back to the manual attention implementation, but " + "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + output_attentions=output_attentions, + ) + + batch_size, height, width, _ = hidden_states.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = ( + self.qkv(hidden_states) + .reshape(batch_size, height * width, 3, self.num_attention_heads, -1) + .permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (B * nHead, H * W, C) + query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0) + + attn_bias = None + if self.use_rel_pos: + decomposed_rel_pos = self.get_decomposed_rel_pos( + query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + decomposed_rel_pos = decomposed_rel_pos.reshape( + batch_size, self.num_attention_heads, height * width, height * width + ) + attn_bias = decomposed_rel_pos + + query = query.view(batch_size, self.num_attention_heads, height * width, -1) + key = key.view(batch_size, self.num_attention_heads, height * width, -1) + value = value.view(batch_size, self.num_attention_heads, height * width, -1) + + attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias) + + attn_output = ( + attn_output.view(batch_size, self.num_attention_heads, height, width, -1) + .permute(0, 2, 3, 1, 4) + .reshape(batch_size, height, width, -1) + ) + + attn_output = self.proj(attn_output) + + return attn_output, None + + +SAM_HQ_VISION_ATTENTION_CLASSES = { + "eager": SamHQVisionAttention, + "sdpa": SamHQVisionSdpaAttention, +} + + +class SamHQVisionLayer(nn.Module): + def __init__(self, config, window_size): + super().__init__() + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attn = SAM_HQ_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size) + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SamHQMLPBlock(config) + self.window_size = window_size + + def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Args: + Partition into non-overlapping windows with padding if needed. + hidden_states (tensor): input tokens with [batch_size, height, width, channel]. window_size (int): window + size. + + Returns: + windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel]. + (pad_height, pad_width): padded height and width before partition + """ + batch_size, height, width, channel = hidden_states.shape + + pad_h = (window_size - height % window_size) % window_size + pad_w = (window_size - width % window_size) % window_size + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h)) + pad_height, pad_width = height + pad_h, width + pad_w + + hidden_states = hidden_states.reshape( + batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel + ) + windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel) + return windows, (pad_height, pad_width) + + def window_unpartition( + self, windows: torch.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int] + ) -> torch.Tensor: + """ + Args: + Window unpartition into original sequences and removing padding. + hidden_states (tensor): + input tokens with [batch_size * num_windows, window_size, window_size, channel]. + window_size (int): + window size. + padding_shape (Tuple): + padded height and width (pad_height, pad_width). + original_shape (Tuple): original height and width (height, width) before padding. + + Returns: + hidden_states: unpartitioned sequences with [batch_size, height, width, channel]. + """ + pad_height, pad_width = padding_shape + height, width = original_shape + batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size) + hidden_states = windows.reshape( + batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1 + ) + hidden_states = ( + hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1) + ) + + hidden_states = hidden_states[:, :height, :width, :].contiguous() + return hidden_states + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + # Window partition + if self.window_size > 0: + height, width = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size) + + hidden_states, attn_weights = self.attn( + hidden_states=hidden_states, + output_attentions=output_attentions, + ) + # Reverse window partition + if self.window_size > 0: + hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width)) + + hidden_states = residual + hidden_states + layernorm_output = self.layer_norm2(hidden_states) + hidden_states = hidden_states + self.mlp(layernorm_output) + + outputs = (hidden_states,) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SamHQVisionNeck(nn.Module): + def __init__(self, config: SamHQVisionConfig): + super().__init__() + self.config = config + + self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False) + self.layer_norm1 = SamHQLayerNorm(config.output_channels, data_format="channels_first") + self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False) + self.layer_norm2 = SamHQLayerNorm(config.output_channels, data_format="channels_first") + + def forward(self, hidden_states): + hidden_states = hidden_states.permute(0, 3, 1, 2) + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm1(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + return hidden_states + + +class SamHQVisionEncoder(nn.Module): + def __init__(self, config: SamHQVisionConfig): + super().__init__() + self.config = config + self.image_size = config.image_size + + self.patch_embed = SamHQPatchEmbeddings(config) + + self.pos_embed = None + if config.use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros( + 1, + config.image_size // config.patch_size, + config.image_size // config.patch_size, + config.hidden_size, + ) + ) + + self.layers = nn.ModuleList() + for i in range(config.num_hidden_layers): + layer = SamHQVisionLayer( + config, + window_size=config.window_size if i not in config.global_attn_indexes else 0, + ) + self.layers.append(layer) + + self.neck = SamHQVisionNeck(config) + + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.patch_embed + + @can_return_tuple + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SamHQVisionEncoderOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.patch_embed(pixel_values) + if self.pos_embed is not None: + hidden_states = hidden_states + self.pos_embed + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + intermediate_embeddings = [] + + for i, layer_module in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + ) + else: + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) + + hidden_states = layer_outputs[0] + + # Collect embeddings from non-windowed blocks + if hasattr(layer_module, "window_size") and layer_module.window_size == 0: + intermediate_embeddings.append(hidden_states) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.neck(hidden_states) + + if not return_dict: + outputs = (hidden_states, intermediate_embeddings) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_self_attentions,) + return outputs + + return SamHQVisionEncoderOutput( + last_hidden_state=hidden_states, + intermediate_embeddings=intermediate_embeddings, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class SamHQLayerNorm(nn.Module): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {self.data_format}") + self.normalized_shape = (normalized_shape,) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.data_format == "channels_last": + x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + input_dtype = x.dtype + x = x.float() + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = x.to(dtype=input_dtype) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class SamHQAttention(nn.Module): + """ + SAM_HQ's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. + """ + + def __init__(self, config, downsample_rate=None): + super().__init__() + self.hidden_size = config.hidden_size + + downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate + + self.internal_dim = config.hidden_size // downsample_rate + self.num_attention_heads = config.num_attention_heads + if self.internal_dim % config.num_attention_heads != 0: + raise ValueError("num_attention_heads must divide hidden_size.") + + self.q_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.k_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.v_proj = nn.Linear(self.hidden_size, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, self.hidden_size) + + def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor: + batch, point_batch_size, n_tokens, channel = hidden_states.shape + c_per_head = channel // num_attention_heads + hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head) + return hidden_states.transpose(1, 2) + + def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor: + batch, n_heads, n_tokens, c_per_head = hidden_states.shape + hidden_states = hidden_states.transpose(1, 2) + return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head) + + def forward( + self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None + ) -> Tensor: + # Input projections + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + point_batch_size = query.shape[1] + # Separate into heads + query = self._separate_heads(query, self.num_attention_heads) + key = self._separate_heads(key, self.num_attention_heads) + value = self._separate_heads(value, self.num_attention_heads) + + # SamHQAttention + _, _, _, c_per_head = query.shape + attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens + attn = attn / (c_per_head**0.5) + attn = torch.softmax(attn, dim=-1) + + if attention_similarity is not None: + attn = attn + attention_similarity + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ value + out = self._recombine_heads(out, point_batch_size) + out = self.out_proj(out) + + return out + + +class SamHQSdpaAttention(SamHQAttention): + """ + SAM_HQ's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and + values. Using SDPA instead of the default attention. + """ + + def __init__(self, config, downsample_rate=None): + super().__init__(config, downsample_rate) + + def forward( + self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None + ) -> Tensor: + # Input projections + query = self.q_proj(query) + key = self.k_proj(key) + value = self.v_proj(value) + + point_batch_size = query.shape[1] + # Separate into heads + query = self._separate_heads(query, self.num_attention_heads) + key = self._separate_heads(key, self.num_attention_heads) + value = self._separate_heads(value, self.num_attention_heads) + + # Scaled dot product attention + attn_mask = None + if attention_similarity is not None: + attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1) + + out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask) + + # Get output + out = self._recombine_heads(out, point_batch_size) + out = self.out_proj(out) + + return out + + +SAM_HQ_ATTENTION_CLASSES = { + "eager": SamHQAttention, + "sdpa": SamHQSdpaAttention, +} + + +class SamHQTwoWayAttentionBlock(nn.Module): + def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False): + """ + A transformer block with four layers: + (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on + sparse inputs (4) cross attention of dense inputs -> sparse inputs + + Arguments: + config (`SamHQMaskDecoderConfig`): + The configuration file used to instantiate the block + attention_downsample_rate (*optionalk*, int, defaults to 2): + The downsample ratio of the block used to reduce the inner dim of the attention. + skip_first_layer_pe (*optional*, bool, defaults to `False`): + Whether or not to skip the addition of the query_point_embedding on the first layer. + """ + super().__init__() + + self.hidden_size = config.hidden_size + self.layer_norm_eps = config.layer_norm_eps + + self.self_attn = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1) + self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) + + self.cross_attn_token_to_image = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation]( + config, downsample_rate=attention_downsample_rate + ) + self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) + + self.mlp = SamHQMLPBlock(config) + self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) + + self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) + self.cross_attn_image_to_token = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation]( + config, downsample_rate=attention_downsample_rate + ) + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, + queries: Tensor, + keys: Tensor, + query_point_embedding: Tensor, + key_point_embedding: Tensor, + attention_similarity: Tensor, + output_attentions: bool = False, + ): + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(query=queries, key=queries, value=queries) + else: + query = queries + query_point_embedding + attn_out = self.self_attn(query=query, key=query, value=queries) + queries = queries + attn_out + queries = self.layer_norm1(queries) + + # Cross attention block, tokens attending to image embedding + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_token_to_image( + query=query, key=key, value=keys, attention_similarity=attention_similarity + ) + queries = queries + attn_out + + queries = self.layer_norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.layer_norm3(queries) + + # Cross attention block, image embedding attending to tokens + query = queries + query_point_embedding + key = keys + key_point_embedding + + attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries) + keys = keys + attn_out + + keys = self.layer_norm4(keys) + + outputs = (queries, keys) + + if output_attentions: + outputs = outputs + (attn_out,) + else: + outputs = outputs + (None,) + + return outputs + + +class SamHQTwoWayTransformer(nn.Module): + def __init__(self, config: SamHQMaskDecoderConfig): + super().__init__() + self.config = config + + self.num_hidden_layers = config.num_hidden_layers + self.layers = nn.ModuleList() + + for i in range(self.num_hidden_layers): + self.layers.append(SamHQTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0))) + + self.final_attn_token_to_image = SAM_HQ_ATTENTION_CLASSES[config._attn_implementation](config) + self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size) + + def forward( + self, + point_embeddings: Tensor, + image_embeddings: Tensor, + image_positional_embeddings: Tensor, + attention_similarity: Tensor, + target_embedding=None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + all_attentions = () + + if image_embeddings is None: + raise ValueError("You have to specify an image_embedding") + + image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1) + + # Prepare queries + queries = point_embeddings + keys = image_embeddings + + # Apply transformer blocks and final layernorm + for layer in self.layers: + if target_embedding is not None: + queries += target_embedding + + queries, keys, attention_outputs = layer( + queries=queries, + keys=keys, + query_point_embedding=point_embeddings, + key_point_embedding=image_positional_embeddings, + attention_similarity=attention_similarity, + output_attentions=output_attentions, + ) + + if output_attentions: + all_attentions = all_attentions + (attention_outputs,) + + # Apply the final attenion layer from the points to the image + query = queries + point_embeddings + key = keys + image_positional_embeddings + + attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys) + + queries = queries + attn_out + queries = self.layer_norm_final_attn(queries) + return queries, keys, all_attentions + + +class SamHQFeedForward(nn.Module): + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False + ): + super().__init__() + self.num_layers = num_layers + self.activation = nn.ReLU() + self.proj_in = nn.Linear(input_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]) + self.sigmoid_output = sigmoid_output + + def forward(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + hidden_states = self.activation(hidden_states) + for layer in self.layers: + hidden_states = self.activation(layer(hidden_states)) + + hidden_states = self.proj_out(hidden_states) + if self.sigmoid_output: + hidden_states = F.sigmoid(hidden_states) + return hidden_states + + +class SamHQMaskDecoder(nn.Module): + def __init__(self, config: SamHQMaskDecoderConfig): + super().__init__() + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + self.iou_token = nn.Embedding(1, self.hidden_size) + self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size) + + self.transformer = SamHQTwoWayTransformer(config) + + self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2) + self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2) + self.upscale_layer_norm = SamHQLayerNorm(self.hidden_size // 4, data_format="channels_first") + self.activation = nn.GELU() + + mlps_list = [] + for _ in range(self.num_mask_tokens): + mlps_list += [SamHQFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)] + self.output_hypernetworks_mlps = nn.ModuleList(mlps_list) + + self.iou_prediction_head = SamHQFeedForward( + self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth + ) + + self.hq_token = nn.Embedding(1, self.hidden_size) + self.hq_mask_mlp = SamHQFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3) + self.num_mask_tokens = self.num_mask_tokens + 1 + + # Compress ViT features + self.compress_vit_conv1 = nn.ConvTranspose2d(config.vit_dim, self.hidden_size, kernel_size=2, stride=2) + self.compress_vit_norm = SamHQLayerNorm(self.hidden_size, data_format="channels_first") + self.compress_vit_conv2 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 8, kernel_size=2, stride=2) + + # Embedding encoder + self.encoder_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2) + self.encoder_norm = SamHQLayerNorm(self.hidden_size // 4, data_format="channels_first") + self.encoder_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2) + + # Embedding mask feature + self.mask_conv1 = nn.Conv2d(self.hidden_size // 8, self.hidden_size // 4, kernel_size=3, stride=1, padding=1) + self.mask_norm = SamHQLayerNorm(self.hidden_size // 4, data_format="channels_first") + self.mask_conv2 = nn.Conv2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=3, stride=1, padding=1) + + def forward( + self, + image_embeddings: torch.Tensor, + image_positional_embeddings: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + hq_token_only: bool, + intermediate_embeddings: Optional[List[torch.Tensor]] = None, + output_attentions: Optional[bool] = None, + attention_similarity: Optional[torch.Tensor] = None, + target_embedding: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict high-quality masks given image and prompt embeddings. + + Args: + image_embeddings (`torch.Tensor`): + The embeddings from the image encoder. + image_positional_embedding (`torch.Tensor`): + Positional encoding with the shape of image_embeddings. + sparse_prompt_embeddings (`torch.Tensor`): + The embeddings of the points and boxes. + dense_prompt_embeddings (`torch.Tensor`): + The embeddings of the mask inputs. + multimask_output (bool): + Whether to return multiple masks or a single mask. + hq_token_only (bool): + Whether to use only the high-quality token output or combine with SAM output. + intermediate_embeddings (`torch.Tensor`): + Intermediate embeddings from the vision encoder for feature fusion. + output_attentions (bool, *optional*): + Whether or not to return the attentions tensors of all attention layers. + attention_similarity (`torch.Tensor`, *optional*): + Optional tensor for attention similarity computation. + target_embedding (`torch.Tensor`, *optional*): + Optional target embedding for transformer processing. + + Returns: + `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple of tensors containing: + - A tensor of shape `(batch_size, num_prompts, num_masks, height, width)` containing the output masks. + - A tensor of shape `(batch_size, num_prompts, num_masks)` containing the iou predictions for each mask. + - (Optional) A tuple containing attention tensors if output_attentions is True. + """ + batch_size, num_channels, height, width = image_embeddings.shape + point_batch_size = sparse_prompt_embeddings.shape[1] + + has_intermediate = intermediate_embeddings is not None and len(intermediate_embeddings) > 0 + + if has_intermediate: + vit_features = intermediate_embeddings[0].permute(0, 3, 1, 2).contiguous() + + embed_encode = self.encoder_conv1(image_embeddings) + embed_encode = self.activation(self.encoder_norm(embed_encode)) + embed_encode = self.encoder_conv2(embed_encode) + + if has_intermediate: + compressed_vit_features = self.compress_vit_conv1(vit_features) + compressed_vit_features = self.activation(self.compress_vit_norm(compressed_vit_features)) + compressed_vit_features = self.compress_vit_conv2(compressed_vit_features) + + hq_features = embed_encode + compressed_vit_features + else: + hq_features = embed_encode + + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0) + output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) + + if torch.any(sparse_prompt_embeddings != 0): + tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2) + else: + tokens = output_tokens + point_embeddings = tokens.to(self.iou_token.weight.dtype) + + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0) + image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) + + point_embedding, image_embeddings, attentions = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + output_attentions=output_attentions, + ) + iou_token_out = point_embedding[:, :, 0, :] + mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :] + + image_embeddings = image_embeddings.transpose(2, 3).reshape( + batch_size * point_batch_size, num_channels, height, width + ) + + upscaled_embedding = self.upscale_conv1(image_embeddings) + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding)) + + upscaled_embedding_hq = self.mask_conv1(upscaled_embedding) + upscaled_embedding_hq = self.activation(self.mask_norm(upscaled_embedding_hq)) + upscaled_embedding_hq = self.mask_conv2(upscaled_embedding_hq) + + if hq_features.shape[0] == 1: + hq_features = hq_features.repeat(batch_size * point_batch_size, 1, 1, 1) + elif hq_features.shape[0] == batch_size and batch_size * point_batch_size != batch_size: + hq_features = hq_features.repeat_interleave(point_batch_size, 0) + upscaled_embedding_hq = upscaled_embedding_hq + hq_features + + hyper_in_list = [] + for mask_token_index in range(self.num_mask_tokens): + if mask_token_index < self.num_mask_tokens - 1: + current_mlp = self.output_hypernetworks_mlps[mask_token_index] + else: + current_mlp = self.hq_mask_mlp + hyper_in_list += [current_mlp(mask_tokens_out[:, :, mask_token_index, :])] + + hyper_in = torch.stack(hyper_in_list, dim=2) + _, num_channels, height, width = upscaled_embedding.shape + upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width) + upscaled_embedding_hq = upscaled_embedding_hq.reshape( + batch_size, point_batch_size, num_channels, height * width + ) + + masks_sam = (hyper_in[:, :, : self.num_mask_tokens - 1] @ upscaled_embedding).reshape( + batch_size, point_batch_size, -1, height, width + ) + masks_hq = (hyper_in[:, :, self.num_mask_tokens - 1 :] @ upscaled_embedding_hq).reshape( + batch_size, point_batch_size, -1, height, width + ) + masks = torch.cat([masks_sam, masks_hq], dim=2) + + iou_pred = self.iou_prediction_head(iou_token_out) + + if multimask_output: + mask_slice = slice(1, self.num_mask_tokens - 1) + iou_pred = iou_pred[:, :, mask_slice] + # Sort the IoU scores in descending order and get indices + iou_pred_sorted, sort_indices = torch.sort(iou_pred, dim=2, descending=True) + # Reorder the masks according to sorted scores + masks_sam = masks[:, :, mask_slice, :, :] + masks_sam = torch.gather( + masks_sam, + 2, + sort_indices[..., None, None].expand(-1, -1, -1, masks_sam.shape[3], masks_sam.shape[4]), + ) + # Update iou_pred with sorted scores + iou_pred = iou_pred_sorted + else: + mask_slice = slice(0, 1) + iou_pred = iou_pred[:, :, mask_slice] + masks_sam = masks[:, :, mask_slice, :, :] + + masks_hq = masks[:, :, slice(self.num_mask_tokens - 1, self.num_mask_tokens), :, :] + if hq_token_only: + masks = masks_hq + else: + masks = masks_sam + masks_hq + + outputs = (masks, iou_pred) + if output_attentions: + outputs = outputs + (attentions,) + else: + outputs = outputs + (None,) + + return outputs + + +class SamHQPreTrainedModel(PreTrainedModel): + config_class = SamHQConfig + base_model_prefix = "sam_hq" + main_input_name = "pixel_values" + _no_split_modules = ["SamHQVisionAttention"] + supports_gradient_checkpointing = True + _supports_sdpa = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, (SamHQLayerNorm, nn.LayerNorm)): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, SamHQVisionAttention): + if module.use_rel_pos: + module.rel_pos_h.data.zero_() + module.rel_pos_w.data.zero_() + if isinstance(module, SamHQVisionEncoder): + if module.pos_embed is not None: + module.pos_embed.data.zero_() + + +SAM_HQ_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SamHQProcessor`]. See [`SamHQProcessor.__call__`] for + details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +SAM_HQ_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`SamHQConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """The vision model from SAM-HQ without any head or projection on top.""", + SAM_HQ_START_DOCSTRING, +) +class SamHQVisionModel(SamHQPreTrainedModel): + config_class = SamHQVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SamHQVisionConfig): + super().__init__(config) + self.vision_encoder = SamHQVisionEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_encoder.patch_embed + + @add_start_docstrings_to_model_forward(SAM_HQ_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SamHQVisionEncoderOutput, config_class=SamHQVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SamHQVisionEncoderOutput]: + r""" + Returns: + + """ + return self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SamHQPositionalEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.scale = config.hidden_size // 2 + self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats))) + + def forward(self, input_coords, input_shape=None): + """Positionally encode points that are normalized to [0,1].""" + coordinates = input_coords.clone() + + if input_shape is not None: + coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1] + coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0] + + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coordinates = 2 * coordinates - 1 + coordinates = coordinates.to(self.positional_embedding.dtype) + coordinates = coordinates @ self.positional_embedding + coordinates = 2 * np.pi * coordinates + # outputs d_1 x ... x d_n x channel shape + return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1) + + +class SamHQMaskEmbedding(nn.Module): + def __init__(self, config: SamHQPromptEncoderConfig): + super().__init__() + self.mask_input_channels = config.mask_input_channels // 4 + self.activation = ACT2FN[config.hidden_act] + self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2) + self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2) + self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1) + self.layer_norm1 = SamHQLayerNorm( + self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first" + ) + self.layer_norm2 = SamHQLayerNorm( + self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first" + ) + + def forward(self, masks): + hidden_states = self.conv1(masks) + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.conv2(hidden_states) + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.activation(hidden_states) + dense_embeddings = self.conv3(hidden_states) + return dense_embeddings + + +class SamHQPromptEncoder(nn.Module): + def __init__(self, config: SamHQPromptEncoderConfig): + super().__init__() + self.shared_embedding = SamHQPositionalEmbedding(config.vision_config) + config = config.prompt_encoder_config + self.mask_embed = SamHQMaskEmbedding(config) + self.no_mask_embed = nn.Embedding(1, config.hidden_size) + + self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size) + self.input_image_size = config.image_size + + self.point_embed = nn.ModuleList( + [nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)] + ) + self.hidden_size = config.hidden_size + self.not_a_point_embed = nn.Embedding(1, config.hidden_size) + + def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1]) + target_labels_shape = (points.shape[0], points.shape[1], 1) + padding_point = torch.zeros(target_point_shape, device=points.device) + padding_label = -torch.ones(target_labels_shape, device=labels.device) + points = torch.cat([points, padding_point], dim=2) + labels = torch.cat([labels, padding_label], dim=2) + input_shape = (self.input_image_size, self.input_image_size) + point_embedding = self.shared_embedding(points, input_shape) + + # torch.where and expanding the labels tensor is required by the ONNX export + point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding) + + # This is required for the ONNX export. The dtype, device need to be explicitly + # specified as otherwise torch.onnx.export interprets as double + point_embedding = torch.where( + labels[..., None] != -10, + point_embedding, + torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device), + ) + + point_embedding = torch.where( + (labels == 0)[:, :, :, None], + point_embedding + self.point_embed[0].weight[None, None, :, :], + point_embedding, + ) + + point_embedding = torch.where( + (labels == 1)[:, :, :, None], + point_embedding + self.point_embed[1].weight[None, None, :, :], + point_embedding, + ) + + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + batch_size, nb_boxes = boxes.shape[:2] + coords = boxes.reshape(batch_size, nb_boxes, 2, 2) + input_shape = (self.input_image_size, self.input_image_size) + corner_embedding = self.shared_embedding(coords, input_shape) + corner_embedding[:, :, 0, :] += self.point_embed[2].weight + corner_embedding[:, :, 1, :] += self.point_embed[3].weight + return corner_embedding + + def forward( + self, + input_points: Optional[Tuple[torch.Tensor, torch.Tensor]], + input_labels: Optional[torch.Tensor], + input_boxes: Optional[torch.Tensor], + input_masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense embeddings. + + Args: + points (`torch.Tensor`, *optional*): + point coordinates and labels to embed. + boxes (`torch.Tensor`, *optional*): + boxes to embed + masks (`torch.Tensor`, *optional*): + masks to embed + """ + sparse_embeddings = None + batch_size = 1 + target_device = self.shared_embedding.positional_embedding.device + if input_points is not None: + batch_size, point_batch_size = input_points.shape[:2] + if input_labels is None: + raise ValueError("If points are provided, labels must also be provided.") + point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None)) + sparse_embeddings = point_embeddings + if input_boxes is not None: + batch_size = input_boxes.shape[0] + box_embeddings = self._embed_boxes(input_boxes) + if sparse_embeddings is None: + sparse_embeddings = box_embeddings + else: + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2) + if input_masks is not None: + dense_embeddings = self.mask_embed(input_masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + if sparse_embeddings is None: + sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device) + + return sparse_embeddings, dense_embeddings + + +SAM_HQ_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SamHQProcessor`]. See [`SamHQProcessor.__call__`] for + details. + input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`): + Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much + better results. The points can be obtained by passing a list of list of list to the processor that will + create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the + second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict + per input point), the third dimension is the number of points per segmentation mask (it is possible to pass + multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) + coordinates of the point. If a different number of points is passed either for each image, or for each + mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the + computation of the embedding will be skipped for these points using the labels. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`): + Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the + official implementation, there are 3 types of labels + + - `1`: the point is a point that contains the object of interest + - `0`: the point is a point that does not contain the object of interest + - `-1`: the point corresponds to the background + + We added the label: + + - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + + The padding labels should be automatically done by the processor. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`): + Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to + much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, + that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch + size, the number of boxes per image and the coordinates of the top left and bottom right point of the box. + In the order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + + input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`): + SAM_HQ model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to + generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be + manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + + image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory + efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `forward` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the + "best" mask, by specifying `multimask_output=False`. + attention_similarity (`torch.FloatTensor`, *optional*): + Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the + model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048). + target_embedding (`torch.FloatTensor`, *optional*): + Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case + the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "Segment Anything Model HQ (SAM-HQ) for generating masks,given an input image and", + " optional 2D location and bounding boxes.", + SAM_HQ_START_DOCSTRING, +) +class SamHQModel(SamHQPreTrainedModel): + _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + + _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] + + def __init__(self, config): + super().__init__(config) + self.shared_image_embedding = SamHQPositionalEmbedding(config.vision_config) + self.vision_encoder = SamHQVisionEncoder(config.vision_config) + self.prompt_encoder = SamHQPromptEncoder(config) + + self.mask_decoder = SamHQMaskDecoder(config.mask_decoder_config) + + self.post_init() + + def _tie_weights(self): + self.prompt_encoder.shared_embedding.positional_embedding.data = ( + self.shared_image_embedding.positional_embedding.data + ) + + def get_input_embeddings(self): + return self.vision_encoder.get_input_embeddings() + + def get_image_wide_positional_embeddings(self): + size = self.config.prompt_encoder_config.image_embedding_size + target_device = self.shared_image_embedding.positional_embedding.device + target_dtype = self.shared_image_embedding.positional_embedding.dtype + grid = torch.ones((size, size), device=target_device, dtype=target_dtype) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / size + x_embed = x_embed / size + + positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1)) + return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width + + @torch.no_grad() + def get_image_embeddings( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + """ + vision_output = self.vision_encoder( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeddings = vision_output[0] + intermediate_embeddings = vision_output[1] + + return image_embeddings, intermediate_embeddings + + @torch.no_grad() + def get_prompt_embeddings( + self, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + ): + r""" + Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder. + + Args: + input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`): + Optional input points for the prompt encoder. The padding of the point is automatically done by the + processor. `point_batch_size` refers to the number of masks that we want the model to predict per + point. The model will output `point_batch_size` times 3 masks in total. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`): + Optional input labels for the prompt encoder. The padding of the labels is automatically done by the + processor, or can be fed by the user. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`): + Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the + processor. users can also pass manually the input boxes. + input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`): + Optional input masks for the prompt encoder. + """ + prompt_output = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + return prompt_output + + @can_return_tuple + @add_start_docstrings_to_model_forward(SAM_HQ_INPUTS_DOCSTRING) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + image_embeddings: Optional[torch.FloatTensor] = None, + multimask_output: bool = True, + hq_token_only: bool = False, + attention_similarity: Optional[torch.FloatTensor] = None, + target_embedding: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + intermediate_embeddings: Optional[List[torch.FloatTensor]] = None, + **kwargs, + ) -> List[Dict[str, torch.Tensor]]: + r""" + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SamHQProcessor`]. See [`SamHQProcessor.__call__`] for + details. + input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`): + Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much + better results. The points can be obtained by passing a list of list of list to the processor that will + create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the + second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict + per input point), the third dimension is the number of points per segmentation mask (it is possible to pass + multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) + coordinates of the point. If a different number of points is passed either for each image, or for each + mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the + computation of the embedding will be skipped for these points using the labels. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`): + Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the + official implementation, there are 3 types of labels + + - `1`: the point is a point that contains the object of interest + - `0`: the point is a point that does not contain the object of interest + - `-1`: the point corresponds to the background + + We added the label: + + - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + + The padding labels should be automatically done by the processor. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`): + Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to + much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, + that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch + size, the number of boxes per image and the coordinates of the top left and botton right point of the box. + In the order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + + input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`): + SAM_HQ model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to + generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be + manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + + image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory + efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `forward` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the + "best" mask, by specifying `multimask_output=False`. + hq_token_only (`bool`, *optional*, defaults to `False`): + Whether to use only the HQ token path for mask generation. When False, combines both standard and HQ paths. + This is specific to SAM-HQ's architecture. + attention_similarity (`torch.FloatTensor`, *optional*): + Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the + model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048). + target_embedding (`torch.FloatTensor`, *optional*): + Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case + the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + intermediate_embeddings (`List[torch.FloatTensor]`, *optional*): + Intermediate embeddings from vision encoder's non-windowed blocks, used by SAM-HQ for enhanced mask quality. + Required when providing pre-computed image_embeddings instead of pixel_values. + Example: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoModel, AutoProcessor + + >>> model = AutoModel.from_pretrained("sushmanth/sam_hq_vit_b") + >>> processor = AutoProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png" + >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + >>> input_points = [[[400, 650]]] # 2D location of a window on the car + >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt") + + >>> # Get high-quality segmentation mask + >>> outputs = model(**inputs) + + >>> # For high-quality mask only + >>> outputs = model(**inputs, hq_token_only=True) + + >>> # Postprocess masks + >>> masks = processor.post_process_masks( + ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"] + ... ) + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None and image_embeddings is None: + raise ValueError("Either pixel_values or image_embeddings must be provided.") + + if pixel_values is not None and image_embeddings is not None: + raise ValueError("Only one of pixel_values and image_embeddings can be provided.") + + if input_points is not None and len(input_points.shape) != 4: + raise ValueError( + "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`." + f" got {input_points.shape}." + ) + + if input_boxes is not None and len(input_boxes.shape) != 3: + raise ValueError( + "The input_boxes must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`." + f" got {input_boxes.shape}." + ) + + # Add validation for point and box batch sizes + if input_points is not None and input_boxes is not None: + point_batch_size = input_points.shape[1] + box_batch_size = input_boxes.shape[1] + if point_batch_size != box_batch_size: + raise ValueError( + "You should provide as many bounding boxes as input points per box. Got {} and {}.".format( + point_batch_size, box_batch_size + ) + ) + + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0] + image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) + + vision_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + vision_outputs = self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if return_dict: + image_embeddings = vision_outputs.last_hidden_state + intermediate_embeddings = vision_outputs.intermediate_embeddings + if output_hidden_states: + vision_hidden_states = vision_outputs.hidden_states + if output_attentions: + vision_attentions = vision_outputs.attentions + else: + image_embeddings = vision_outputs[0] + intermediate_embeddings = vision_outputs[1] + if output_hidden_states: + vision_hidden_states = vision_outputs[2] + if output_attentions: + vision_attentions = vision_outputs[-1] + + if input_points is not None and input_labels is None: + input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + + # Predict masks + low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder( + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + hq_token_only=hq_token_only, + intermediate_embeddings=intermediate_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + output_attentions=output_attentions, + ) + + if not return_dict: + output = (iou_predictions, low_res_masks) + if output_hidden_states: + output = output + (vision_hidden_states,) + + if output_attentions: + output = output + (vision_attentions, mask_decoder_attentions) + return output + + return SamHQImageSegmentationOutput( + iou_scores=iou_predictions, + pred_masks=low_res_masks, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + mask_decoder_attentions=mask_decoder_attentions, + ) + + +__all__ = ["SamHQModel", "SamHQPreTrainedModel", "SamHQVisionModel"] diff --git a/src/transformers/models/sam_hq/modular_sam_hq.py b/src/transformers/models/sam_hq/modular_sam_hq.py new file mode 100644 index 0000000000..b86e300006 --- /dev/null +++ b/src/transformers/models/sam_hq/modular_sam_hq.py @@ -0,0 +1,737 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...utils import add_start_docstrings, logging +from ..sam.configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig +from ..sam.modeling_sam import ( + SamFeedForward, + SamImageSegmentationOutput, + SamLayerNorm, + SamModel, + SamPreTrainedModel, + SamTwoWayTransformer, + SamVisionEncoder, + SamVisionEncoderOutput, + SamVisionModel, +) + + +logger = logging.get_logger(__name__) + + +class SamHQPromptEncoderConfig(SamPromptEncoderConfig): + r""" + This is the configuration class to store the configuration of a [`SamHQPromptEncoderModel`].The [`SamHQPromptEncoderModel`] + module is used to encode the input 2D points and bounding boxes. Instantiating a configuration defaults will yield a + similar configuration to that of the SAM_HQ model. The configuration is used to store the configuration of the model. + [Uminosachi/sam-hq](https://huggingface.co/Uminosachi/sam-hq) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model's output.Read the documentation from + [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden states. + image_size (`int`, *optional*, defaults to 1024): + The expected output resolution of the image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + mask_input_channels (`int`, *optional*, defaults to 16): + The number of channels to be fed to the `MaskDecoder` module. + num_point_embeddings (`int`, *optional*, defaults to 4): + The number of point embeddings to be used. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the encoder and pooler. + """ + + pass + + +class SamHQVisionConfig(SamVisionConfig): + pass + + +class SamHQMaskDecoderConfig(SamMaskDecoderConfig): + r""" + vit_dim (`int`, *optional*, defaults to 768): + Dimensionality of the Vision Transformer (ViT) used in the `SamHQMaskDecoder` module. + """ + + def __init__( + self, + vit_dim=768, + **super_kwargs, + ): + super().__init__(**super_kwargs) + self.vit_dim = vit_dim + + +class SamHQConfig(SamConfig): + r""" + [`SamHQConfig`] is the configuration class to store the configuration of a [`SamHQModel`]. It is used to instantiate a + SAM-HQ model according to the specified arguments, defining the vision model, prompt-encoder model and mask decoder + configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the + SAM-HQ-ViT-H [sushmanth/sam_hq_vit_h](https://huggingface.co/sushmanth/sam_hq_vit_h) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (Union[`dict`, `SamHQVisionConfig`], *optional*): + Dictionary of configuration options used to initialize [`SamHQVisionConfig`]. + prompt_encoder_config (Union[`dict`, `SamHQPromptEncoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`SamHQPromptEncoderConfig`]. + mask_decoder_config (Union[`dict`, `SamHQMaskDecoderConfig`], *optional*): + Dictionary of configuration options used to initialize [`SamHQMaskDecoderConfig`]. + kwargs (*optional*): + Dictionary of keyword arguments. + """ + + pass + + +@dataclass +class SamHQVisionEncoderOutput(SamVisionEncoderOutput): + """ + intermediate_embeddings (`list(torch.FloatTensor)`, *optional*): + A list of intermediate embeddings collected from certain blocks within the model, typically those without + windowed attention. Each element in the list is of shape `(batch_size, sequence_length, hidden_size)`. + This is specific to SAM-HQ and not present in base SAM. + """ + + intermediate_embeddings: Optional[List[torch.FloatTensor]] = None + + +@dataclass +class SamHQImageSegmentationOutput(SamImageSegmentationOutput): + pass + + +class SamHQVisionEncoder(SamVisionEncoder): + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SamHQVisionEncoderOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.patch_embed(pixel_values) + if self.pos_embed is not None: + hidden_states = hidden_states + self.pos_embed + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + intermediate_embeddings = [] + + for i, layer_module in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + ) + else: + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) + + hidden_states = layer_outputs[0] + + # Collect embeddings from non-windowed blocks + if hasattr(layer_module, "window_size") and layer_module.window_size == 0: + intermediate_embeddings.append(hidden_states) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.neck(hidden_states) + + if not return_dict: + outputs = (hidden_states, intermediate_embeddings) + if output_hidden_states: + outputs = outputs + (all_hidden_states,) + if output_attentions: + outputs = outputs + (all_self_attentions,) + return outputs + + return SamHQVisionEncoderOutput( + last_hidden_state=hidden_states, + intermediate_embeddings=intermediate_embeddings, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class SamHQLayerNorm(SamLayerNorm): + pass + + +class SamHQTwoWayTransformer(SamTwoWayTransformer): + pass + + +class SamHQFeedForward(SamFeedForward): + pass + + +class SamHQMaskDecoder(nn.Module): + def __init__(self, config: SamHQMaskDecoderConfig): + super().__init__() + self.hidden_size = config.hidden_size + + self.num_multimask_outputs = config.num_multimask_outputs + self.num_mask_tokens = config.num_multimask_outputs + 1 + + self.iou_token = nn.Embedding(1, self.hidden_size) + self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size) + + self.transformer = SamHQTwoWayTransformer(config) + + self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2) + self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2) + self.upscale_layer_norm = SamHQLayerNorm(self.hidden_size // 4, data_format="channels_first") + self.activation = nn.GELU() + + mlps_list = [] + for _ in range(self.num_mask_tokens): + mlps_list += [SamHQFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)] + self.output_hypernetworks_mlps = nn.ModuleList(mlps_list) + + self.iou_prediction_head = SamHQFeedForward( + self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth + ) + + self.hq_token = nn.Embedding(1, self.hidden_size) + self.hq_mask_mlp = SamHQFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3) + self.num_mask_tokens = self.num_mask_tokens + 1 + + # Compress ViT features + self.compress_vit_conv1 = nn.ConvTranspose2d(config.vit_dim, self.hidden_size, kernel_size=2, stride=2) + self.compress_vit_norm = SamHQLayerNorm(self.hidden_size, data_format="channels_first") + self.compress_vit_conv2 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 8, kernel_size=2, stride=2) + + # Embedding encoder + self.encoder_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2) + self.encoder_norm = SamHQLayerNorm(self.hidden_size // 4, data_format="channels_first") + self.encoder_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2) + + # Embedding mask feature + self.mask_conv1 = nn.Conv2d(self.hidden_size // 8, self.hidden_size // 4, kernel_size=3, stride=1, padding=1) + self.mask_norm = SamHQLayerNorm(self.hidden_size // 4, data_format="channels_first") + self.mask_conv2 = nn.Conv2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=3, stride=1, padding=1) + + def forward( + self, + image_embeddings: torch.Tensor, + image_positional_embeddings: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + hq_token_only: bool, + intermediate_embeddings: Optional[List[torch.Tensor]] = None, + output_attentions: Optional[bool] = None, + attention_similarity: Optional[torch.Tensor] = None, + target_embedding: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict high-quality masks given image and prompt embeddings. + + Args: + image_embeddings (`torch.Tensor`): + The embeddings from the image encoder. + image_positional_embedding (`torch.Tensor`): + Positional encoding with the shape of image_embeddings. + sparse_prompt_embeddings (`torch.Tensor`): + The embeddings of the points and boxes. + dense_prompt_embeddings (`torch.Tensor`): + The embeddings of the mask inputs. + multimask_output (bool): + Whether to return multiple masks or a single mask. + hq_token_only (bool): + Whether to use only the high-quality token output or combine with SAM output. + intermediate_embeddings (`torch.Tensor`): + Intermediate embeddings from the vision encoder for feature fusion. + output_attentions (bool, *optional*): + Whether or not to return the attentions tensors of all attention layers. + attention_similarity (`torch.Tensor`, *optional*): + Optional tensor for attention similarity computation. + target_embedding (`torch.Tensor`, *optional*): + Optional target embedding for transformer processing. + + Returns: + `Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple of tensors containing: + - A tensor of shape `(batch_size, num_prompts, num_masks, height, width)` containing the output masks. + - A tensor of shape `(batch_size, num_prompts, num_masks)` containing the iou predictions for each mask. + - (Optional) A tuple containing attention tensors if output_attentions is True. + """ + batch_size, num_channels, height, width = image_embeddings.shape + point_batch_size = sparse_prompt_embeddings.shape[1] + + has_intermediate = intermediate_embeddings is not None and len(intermediate_embeddings) > 0 + + if has_intermediate: + vit_features = intermediate_embeddings[0].permute(0, 3, 1, 2).contiguous() + + embed_encode = self.encoder_conv1(image_embeddings) + embed_encode = self.activation(self.encoder_norm(embed_encode)) + embed_encode = self.encoder_conv2(embed_encode) + + if has_intermediate: + compressed_vit_features = self.compress_vit_conv1(vit_features) + compressed_vit_features = self.activation(self.compress_vit_norm(compressed_vit_features)) + compressed_vit_features = self.compress_vit_conv2(compressed_vit_features) + + hq_features = embed_encode + compressed_vit_features + else: + hq_features = embed_encode + + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hq_token.weight], dim=0) + output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1) + + if torch.any(sparse_prompt_embeddings != 0): + tokens = torch.cat([output_tokens, sparse_prompt_embeddings], dim=2) + else: + tokens = output_tokens + point_embeddings = tokens.to(self.iou_token.weight.dtype) + + image_embeddings = image_embeddings + dense_prompt_embeddings + image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0) + image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0) + + point_embedding, image_embeddings, attentions = self.transformer( + point_embeddings=point_embeddings, + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + output_attentions=output_attentions, + ) + iou_token_out = point_embedding[:, :, 0, :] + mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :] + + image_embeddings = image_embeddings.transpose(2, 3).reshape( + batch_size * point_batch_size, num_channels, height, width + ) + + upscaled_embedding = self.upscale_conv1(image_embeddings) + upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding)) + upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding)) + + upscaled_embedding_hq = self.mask_conv1(upscaled_embedding) + upscaled_embedding_hq = self.activation(self.mask_norm(upscaled_embedding_hq)) + upscaled_embedding_hq = self.mask_conv2(upscaled_embedding_hq) + + if hq_features.shape[0] == 1: + hq_features = hq_features.repeat(batch_size * point_batch_size, 1, 1, 1) + elif hq_features.shape[0] == batch_size and batch_size * point_batch_size != batch_size: + hq_features = hq_features.repeat_interleave(point_batch_size, 0) + upscaled_embedding_hq = upscaled_embedding_hq + hq_features + + hyper_in_list = [] + for mask_token_index in range(self.num_mask_tokens): + if mask_token_index < self.num_mask_tokens - 1: + current_mlp = self.output_hypernetworks_mlps[mask_token_index] + else: + current_mlp = self.hq_mask_mlp + hyper_in_list += [current_mlp(mask_tokens_out[:, :, mask_token_index, :])] + + hyper_in = torch.stack(hyper_in_list, dim=2) + _, num_channels, height, width = upscaled_embedding.shape + upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width) + upscaled_embedding_hq = upscaled_embedding_hq.reshape( + batch_size, point_batch_size, num_channels, height * width + ) + + masks_sam = (hyper_in[:, :, : self.num_mask_tokens - 1] @ upscaled_embedding).reshape( + batch_size, point_batch_size, -1, height, width + ) + masks_hq = (hyper_in[:, :, self.num_mask_tokens - 1 :] @ upscaled_embedding_hq).reshape( + batch_size, point_batch_size, -1, height, width + ) + masks = torch.cat([masks_sam, masks_hq], dim=2) + + iou_pred = self.iou_prediction_head(iou_token_out) + + if multimask_output: + mask_slice = slice(1, self.num_mask_tokens - 1) + iou_pred = iou_pred[:, :, mask_slice] + # Sort the IoU scores in descending order and get indices + iou_pred_sorted, sort_indices = torch.sort(iou_pred, dim=2, descending=True) + # Reorder the masks according to sorted scores + masks_sam = masks[:, :, mask_slice, :, :] + masks_sam = torch.gather( + masks_sam, + 2, + sort_indices[..., None, None].expand(-1, -1, -1, masks_sam.shape[3], masks_sam.shape[4]), + ) + # Update iou_pred with sorted scores + iou_pred = iou_pred_sorted + else: + mask_slice = slice(0, 1) + iou_pred = iou_pred[:, :, mask_slice] + masks_sam = masks[:, :, mask_slice, :, :] + + masks_hq = masks[:, :, slice(self.num_mask_tokens - 1, self.num_mask_tokens), :, :] + if hq_token_only: + masks = masks_hq + else: + masks = masks_sam + masks_hq + + outputs = (masks, iou_pred) + if output_attentions: + outputs = outputs + (attentions,) + else: + outputs = outputs + (None,) + + return outputs + + +class SamHQPreTrainedModel(SamPreTrainedModel): + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, SamHQVisionEncoder): + if module.pos_embed is not None: + module.pos_embed.data.zero_() + + +SAM_HQ_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`SamHQConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """The vision model from SAM-HQ without any head or projection on top.""", + SAM_HQ_START_DOCSTRING, +) +class SamHQVisionModel(SamVisionModel): + pass + + +@add_start_docstrings( + "Segment Anything Model HQ (SAM-HQ) for generating masks,given an input image and", + " optional 2D location and bounding boxes.", + SAM_HQ_START_DOCSTRING, +) +class SamHQModel(SamModel): + _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + + _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] + + def __init__(self, config): + super().__init__(config) + self.vision_encoder = SamHQVisionEncoder(config.vision_config) + + self.mask_decoder = SamHQMaskDecoder(config.mask_decoder_config) + + self.post_init() + + @torch.no_grad() + def get_image_embeddings( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + Returns the image embeddings by passing the pixel values through the vision encoder. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Input pixel values + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + """ + vision_output = self.vision_encoder( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + image_embeddings = vision_output[0] + intermediate_embeddings = vision_output[1] + + return image_embeddings, intermediate_embeddings + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + input_points: Optional[torch.FloatTensor] = None, + input_labels: Optional[torch.LongTensor] = None, + input_boxes: Optional[torch.FloatTensor] = None, + input_masks: Optional[torch.LongTensor] = None, + image_embeddings: Optional[torch.FloatTensor] = None, + multimask_output: bool = True, + hq_token_only: bool = False, + attention_similarity: Optional[torch.FloatTensor] = None, + target_embedding: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + intermediate_embeddings: Optional[List[torch.FloatTensor]] = None, + **kwargs, + ) -> List[Dict[str, torch.Tensor]]: + r""" + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`SamHQProcessor`]. See [`SamHQProcessor.__call__`] for + details. + input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`): + Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much + better results. The points can be obtained by passing a list of list of list to the processor that will + create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the + second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict + per input point), the third dimension is the number of points per segmentation mask (it is possible to pass + multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal) + coordinates of the point. If a different number of points is passed either for each image, or for each + mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the + computation of the embedding will be skipped for these points using the labels. + input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`): + Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the + official implementation, there are 3 types of labels + + - `1`: the point is a point that contains the object of interest + - `0`: the point is a point that does not contain the object of interest + - `-1`: the point corresponds to the background + + We added the label: + + - `-10`: the point is a padding point, thus should be ignored by the prompt encoder + + The padding labels should be automatically done by the processor. + input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`): + Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to + much better generated masks. The boxes can be obtained by passing a list of list of list to the processor, + that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch + size, the number of boxes per image and the coordinates of the top left and botton right point of the box. + In the order (`x1`, `y1`, `x2`, `y2`): + + - `x1`: the x coordinate of the top left point of the input box + - `y1`: the y coordinate of the top left point of the input box + - `x2`: the x coordinate of the bottom right point of the input box + - `y2`: the y coordinate of the bottom right point of the input box + + input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`): + SAM_HQ model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to + generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be + manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). + + image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`): + Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory + efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings` + method, and then feed them to the `forward` method instead of feeding the `pixel_values`. + multimask_output (`bool`, *optional*): + In the original implementation and paper, the model always outputs 3 masks per image (or per point / per + bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the + "best" mask, by specifying `multimask_output=False`. + hq_token_only (`bool`, *optional*, defaults to `False`): + Whether to use only the HQ token path for mask generation. When False, combines both standard and HQ paths. + This is specific to SAM-HQ's architecture. + attention_similarity (`torch.FloatTensor`, *optional*): + Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the + model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048). + target_embedding (`torch.FloatTensor`, *optional*): + Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case + the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + intermediate_embeddings (`List[torch.FloatTensor]`, *optional*): + Intermediate embeddings from vision encoder's non-windowed blocks, used by SAM-HQ for enhanced mask quality. + Required when providing pre-computed image_embeddings instead of pixel_values. + Example: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoModel, AutoProcessor + + >>> model = AutoModel.from_pretrained("sushmanth/sam_hq_vit_b") + >>> processor = AutoProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png" + >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + >>> input_points = [[[400, 650]]] # 2D location of a window on the car + >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt") + + >>> # Get high-quality segmentation mask + >>> outputs = model(**inputs) + + >>> # For high-quality mask only + >>> outputs = model(**inputs, hq_token_only=True) + + >>> # Postprocess masks + >>> masks = processor.post_process_masks( + ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"] + ... ) + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None and image_embeddings is None: + raise ValueError("Either pixel_values or image_embeddings must be provided.") + + if pixel_values is not None and image_embeddings is not None: + raise ValueError("Only one of pixel_values and image_embeddings can be provided.") + + if input_points is not None and len(input_points.shape) != 4: + raise ValueError( + "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`." + f" got {input_points.shape}." + ) + + if input_boxes is not None and len(input_boxes.shape) != 3: + raise ValueError( + "The input_boxes must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`." + f" got {input_boxes.shape}." + ) + + # Add validation for point and box batch sizes + if input_points is not None and input_boxes is not None: + point_batch_size = input_points.shape[1] + box_batch_size = input_boxes.shape[1] + if point_batch_size != box_batch_size: + raise ValueError( + "You should provide as many bounding boxes as input points per box. Got {} and {}.".format( + point_batch_size, box_batch_size + ) + ) + + image_positional_embeddings = self.get_image_wide_positional_embeddings() + # repeat with batch size + batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0] + image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1) + + vision_attentions = None + vision_hidden_states = None + + if pixel_values is not None: + vision_outputs = self.vision_encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if return_dict: + image_embeddings = vision_outputs.last_hidden_state + intermediate_embeddings = vision_outputs.intermediate_embeddings + if output_hidden_states: + vision_hidden_states = vision_outputs.hidden_states + if output_attentions: + vision_attentions = vision_outputs.attentions + else: + image_embeddings = vision_outputs[0] + intermediate_embeddings = vision_outputs[1] + if output_hidden_states: + vision_hidden_states = vision_outputs[2] + if output_attentions: + vision_attentions = vision_outputs[-1] + + if input_points is not None and input_labels is None: + input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device) + + sparse_embeddings, dense_embeddings = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + input_masks=input_masks, + ) + + # Predict masks + low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder( + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + hq_token_only=hq_token_only, + intermediate_embeddings=intermediate_embeddings, + attention_similarity=attention_similarity, + target_embedding=target_embedding, + output_attentions=output_attentions, + ) + + if not return_dict: + output = (iou_predictions, low_res_masks) + if output_hidden_states: + output = output + (vision_hidden_states,) + + if output_attentions: + output = output + (vision_attentions, mask_decoder_attentions) + return output + + return SamHQImageSegmentationOutput( + iou_scores=iou_predictions, + pred_masks=low_res_masks, + vision_hidden_states=vision_hidden_states, + vision_attentions=vision_attentions, + mask_decoder_attentions=mask_decoder_attentions, + ) + + +__all__ = [ + "SamHQVisionConfig", + "SamHQMaskDecoderConfig", + "SamHQPromptEncoderConfig", + "SamHQConfig", + "SamHQModel", + "SamHQPreTrainedModel", + "SamHQVisionModel", +] diff --git a/src/transformers/models/sam_hq/processing_samhq.py b/src/transformers/models/sam_hq/processing_samhq.py new file mode 100644 index 0000000000..1be26ce362 --- /dev/null +++ b/src/transformers/models/sam_hq/processing_samhq.py @@ -0,0 +1,330 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for SAMHQ. +""" + +from copy import deepcopy +from typing import List, Optional, Union + +import numpy as np + +from ...image_utils import ImageInput, VideoInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput +from ...utils import is_torch_available + + +if is_torch_available(): + import torch + + +class SamHQImagesKwargs(ImagesKwargs): + segmentation_maps: Optional[ImageInput] + input_points: Optional[List[List[float]]] + input_labels: Optional[List[List[int]]] + input_boxes: Optional[List[List[List[float]]]] + point_pad_value: Optional[int] + + +class SamHQProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: SamHQImagesKwargs + _defaults = { + "images_kwargs": { + "point_pad_value": None, + } + } + + +class SamHQProcessor(ProcessorMixin): + r""" + Constructs a SAM HQ processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a + single processor. + + [`SamHQProcessor`] offers all the functionalities of [`SamImageProcessor`]. See the docstring of + [`~SamImageProcessor.__call__`] for more information. + + Args: + image_processor (`SamImageProcessor`): + An instance of [`SamImageProcessor`]. The image processor is a required input. + """ + + attributes = ["image_processor"] + image_processor_class = "SamImageProcessor" + + optional_call_args = [ + "segmentation_maps", + "input_points", + "input_labels", + "input_boxes", + ] + + def __init__(self, image_processor): + super().__init__(image_processor) + # Ensure image_processor is properly initialized + if not hasattr(self, "image_processor"): + raise ValueError("image_processor was not properly initialized") + if not hasattr(self.image_processor, "size"): + raise ValueError("image_processor.size is not set") + self.target_size = self.image_processor.size["longest_edge"] + + def __call__( + self, + images: Optional[ImageInput] = None, + # The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes` + # arguments that may be passed as a positional argument. + # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details, + # or this conversation for more context: + # https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116 + # This behavior is only needed for backward compatibility and will be removed in future versions. + *args, # to be deprecated + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audio: Optional[AudioInput] = None, + video: Optional[VideoInput] = None, + **kwargs: Unpack[SamHQProcessorKwargs], + ) -> BatchEncoding: + """ + This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D + points and bounding boxes for the model if they are provided. + """ + output_kwargs = self._merge_kwargs( + SamHQProcessorKwargs, + tokenizer_init_kwargs={}, + **kwargs, + **self.prepare_and_validate_optional_call_args(*args), + ) + + input_points = output_kwargs["images_kwargs"].pop("input_points", None) + input_labels = output_kwargs["images_kwargs"].pop("input_labels", None) + input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None) + + encoding_image_processor = self.image_processor( + images, + **output_kwargs["images_kwargs"], + ) + + original_sizes = encoding_image_processor["original_sizes"] + + if hasattr(original_sizes, "numpy"): + original_sizes = original_sizes.numpy() + + input_points, input_labels, input_boxes = self._check_and_preprocess_points( + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + ) + + encoding_image_processor = self._normalize_and_convert( + encoding_image_processor, + original_sizes, + input_points=input_points, + input_labels=input_labels, + input_boxes=input_boxes, + return_tensors=output_kwargs["common_kwargs"].get("return_tensors"), + point_pad_value=output_kwargs["images_kwargs"].get("point_pad_value"), + ) + + return encoding_image_processor + + def _normalize_and_convert( + self, + encoding_image_processor, + original_sizes, + input_points=None, + input_labels=None, + input_boxes=None, + return_tensors="pt", + point_pad_value=-10, + ): + """ + Normalize and convert the image processor output to the expected format. + """ + # Process input points + if input_points is not None: + input_points = self._normalize_batch_coordinates(input_points, original_sizes) + + if not all(point.shape == input_points[0].shape for point in input_points): + if input_labels is not None: + input_points, input_labels = self._pad_points_and_labels( + input_points, input_labels, point_pad_value + ) + + input_points = np.array(input_points) + + # Process input labels + if input_labels is not None: + input_labels = np.array(input_labels) + + # Process input boxes + if input_boxes is not None: + input_boxes = self._normalize_batch_coordinates(input_boxes, original_sizes, is_bounding_box=True) + input_boxes = np.array(input_boxes) + + # Update processor with converted inputs + if input_boxes is not None: + encoding_image_processor["input_boxes"] = self._to_tensor(input_boxes, 3, return_tensors) + if input_points is not None: + encoding_image_processor["input_points"] = self._to_tensor(input_points, 4, return_tensors) + if input_labels is not None: + encoding_image_processor["input_labels"] = self._to_tensor(input_labels, 3, return_tensors) + + return encoding_image_processor + + def _pad_points_and_labels(self, input_points, input_labels, point_pad_value): + r""" + The method pads the 2D points and labels to the maximum number of points in the batch. + """ + expected_nb_points = max([point.shape[0] for point in input_points]) + processed_input_points = [] + for i, point in enumerate(input_points): + if point.shape[0] != expected_nb_points: + point = np.concatenate( + [point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0 + ) + input_labels[i] = np.append(input_labels[i], [point_pad_value]) + processed_input_points.append(point) + input_points = processed_input_points + return input_points, input_labels + + def _normalize_coordinates( + self, target_size: int, coords: np.ndarray, original_size, is_bounding_box=False + ) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the original image size in (H,W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.image_processor._get_preprocess_shape(original_size, longest_edge=target_size) + coords = deepcopy(coords).astype(float) + + if is_bounding_box: + coords = coords.reshape(-1, 2, 2) + + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + + if is_bounding_box: + coords = coords.reshape(-1, 4) + + return coords + + def _preprocess_input(self, inputs, error_message, expected_nesting=1, dtype=None): + """ + Preprocess input by converting torch tensors to numpy arrays and validating structure. + + Args: + inputs: The input to process + error_message: Error message if validation fails + expected_nesting: Expected nesting level (1 for points/labels, 2 for boxes) + dtype: Optional data type for numpy array conversion + + Returns: + Processed input as list of numpy arrays or None + """ + if inputs is None: + return None + + # Convert torch tensor to list if applicable + if hasattr(inputs, "numpy"): + inputs = inputs.numpy().tolist() + + # Validate structure based on expected nesting + valid = isinstance(inputs, list) + current = inputs + + for _ in range(expected_nesting): + if not valid or not current: + break + valid = valid and isinstance(current[0], list) + current = current[0] if current else None + + if not valid: + raise ValueError(error_message) + + # Convert to numpy arrays + return [np.array(item, dtype=dtype) for item in inputs] + + def _check_and_preprocess_points( + self, + input_points=None, + input_labels=None, + input_boxes=None, + ): + r""" + Check and preprocesses the 2D points, labels and bounding boxes. It checks if the input is valid and if they + are, it converts the coordinates of the points and bounding boxes. If a user passes directly a `torch.Tensor`, + it is converted to a `numpy.ndarray` and then to a `list`. + """ + # Process each input type + input_points = self._preprocess_input(input_points, "Input points must be a list of list of floating points.") + + input_labels = self._preprocess_input(input_labels, "Input labels must be a list of list integers.") + + input_boxes = self._preprocess_input( + input_boxes, + "Input boxes must be a list of list of list of floating points.", + expected_nesting=2, + dtype=np.float32, + ) + + return input_points, input_labels, input_boxes + + @property + def model_input_names(self): + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(image_processor_input_names)) + + def post_process_masks(self, *args, **kwargs): + return self.image_processor.post_process_masks(*args, **kwargs) + + def _to_tensor(self, array, min_dim, return_tensors): + """ + Convert numpy array to tensor and ensure proper dimensionality. + Args: + array: The numpy array to convert + min_dim: The minimum number of dimensions the result should have + return_tensors: The type of tensors to return (e.g., "pt" for PyTorch tensors) + Returns: + The converted array or tensor with proper dimensions + """ + if return_tensors == "pt": + array = torch.from_numpy(array) + return array.unsqueeze(1) if array.ndim < min_dim else array + return array + + def _normalize_batch_coordinates(self, inputs, original_sizes, is_bounding_box=False): + """ + Normalize coordinates based on original sizes. + Args: + inputs: List of coordinate arrays + original_sizes: Original sizes of the images + is_bounding_box: Whether inputs are bounding boxes + Returns: + Normalized coordinates as list + """ + if len(original_sizes) != len(inputs): + # Use first original size for all inputs + return [ + self._normalize_coordinates(self.target_size, item, original_sizes[0], is_bounding_box=is_bounding_box) + for item in inputs + ] + else: + # Use paired original sizes for each input + return [ + self._normalize_coordinates(self.target_size, item, size, is_bounding_box=is_bounding_box) + for item, size in zip(inputs, original_sizes) + ] + + +__all__ = ["SamHQProcessor"] diff --git a/src/transformers/pipelines/mask_generation.py b/src/transformers/pipelines/mask_generation.py index 25bd611215..5c0c5e72c8 100644 --- a/src/transformers/pipelines/mask_generation.py +++ b/src/transformers/pipelines/mask_generation.py @@ -189,7 +189,17 @@ class MaskGenerationPipeline(ChunkPipeline): inference_context = self.get_inference_context() with inference_context(): model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device) - image_embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values")) + embeddings = self.model.get_image_embeddings(model_inputs.pop("pixel_values")) + + # Handle both SAM (single tensor) and SAM-HQ (tuple) outputs + if isinstance(embeddings, tuple): + image_embeddings, intermediate_embeddings = embeddings + model_inputs["intermediate_embeddings"] = intermediate_embeddings + else: + image_embeddings = embeddings + # TODO: Identifying the model by the type of its returned embeddings is brittle. + # Consider using a more robust method for distinguishing model types here. + model_inputs["image_embeddings"] = image_embeddings n_points = grid_points.shape[1] diff --git a/tests/models/sam_hq/__init__.py b/tests/models/sam_hq/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/sam_hq/test_modeling_sam_hq.py b/tests/models/sam_hq/test_modeling_sam_hq.py new file mode 100644 index 0000000000..6c62d19746 --- /dev/null +++ b/tests/models/sam_hq/test_modeling_sam_hq.py @@ -0,0 +1,1116 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch SAM-HQ model.""" + +import tempfile +import unittest + +import requests + +from transformers import ( + SamHQConfig, + SamHQMaskDecoderConfig, + SamHQPromptEncoderConfig, + SamHQVisionConfig, + SamHQVisionModel, + pipeline, +) +from transformers.testing_utils import cleanup, require_torch, require_torch_sdpa, slow, torch_device +from transformers.utils import is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import SamHQModel, SamHQProcessor + + +if is_vision_available(): + from PIL import Image + + +class SamHQVisionModelTester: + def __init__( + self, + parent, + hidden_size=36, + intermediate_size=72, + projection_dim=62, + output_channels=32, + num_hidden_layers=2, + num_attention_heads=4, + num_channels=3, + image_size=24, + patch_size=2, + hidden_act="gelu", + layer_norm_eps=1e-06, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + qkv_bias=True, + mlp_ratio=4.0, + use_abs_pos=True, + use_rel_pos=True, + rel_pos_zero_init=False, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + num_pos_feats=16, + mlp_dim=None, + batch_size=2, + ): + self.parent = parent + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.output_channels = output_channels + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.dropout = dropout + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.rel_pos_zero_init = rel_pos_zero_init + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.num_pos_feats = num_pos_feats + self.mlp_dim = mlp_dim + self.batch_size = batch_size + + # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + + def get_config(self): + return SamHQVisionConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + projection_dim=self.projection_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + initializer_range=self.initializer_range, + initializer_factor=self.initializer_factor, + output_channels=self.output_channels, + qkv_bias=self.qkv_bias, + mlp_ratio=self.mlp_ratio, + use_abs_pos=self.use_abs_pos, + use_rel_pos=self.use_rel_pos, + rel_pos_zero_init=self.rel_pos_zero_init, + window_size=self.window_size, + global_attn_indexes=self.global_attn_indexes, + num_pos_feats=self.num_pos_feats, + mlp_dim=self.mlp_dim, + ) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + config = self.get_config() + + return config, pixel_values + + def create_and_check_model(self, config, pixel_values): + model = SamHQVisionModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(pixel_values) + output_size = self.image_size // self.patch_size + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, self.output_channels, output_size, output_size) + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class SamHQVisionModelTest(ModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as SAM's vision encoder does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (SamHQVisionModel,) if is_torch_available() else () + fx_compatible = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + test_torchscript = False + test_torch_exportable = True + + def setUp(self): + self.model_tester = SamHQVisionModelTester(self) + self.config_tester = ConfigTester(self, config_class=SamHQVisionConfig, has_text_modality=False) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="SAM's vision encoder does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + def test_model_get_set_embeddings(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + expected_attention_shape = ( + self.model_tester.batch_size * self.model_tester.num_attention_heads, + 196, + 196, + ) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + attentions = outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-4:]), + list(expected_attention_shape), + ) + + @unittest.skip(reason="SamVisionModel does not support training") + def test_training(self): + pass + + @unittest.skip(reason="SamVisionModel does not support training") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="SamVisionModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="SamVisionModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + @unittest.skip(reason="SamVisionModel does not support training") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="Hidden_states is tested in create_and_check_model tests") + def test_hidden_states_output(self): + pass + + @require_torch_sdpa + def test_sdpa_can_compile_dynamic(self): + self.skipTest(reason="SAM model can't be compiled dynamic yet") + + +class SamHQPromptEncoderTester: + def __init__( + self, + hidden_size=32, + input_image_size=24, + patch_size=2, + mask_input_channels=4, + num_point_embeddings=4, + hidden_act="gelu", + ): + self.hidden_size = hidden_size + self.input_image_size = input_image_size + self.patch_size = patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + + def get_config(self): + return SamHQPromptEncoderConfig( + image_size=self.input_image_size, + patch_size=self.patch_size, + mask_input_channels=self.mask_input_channels, + hidden_size=self.hidden_size, + num_point_embeddings=self.num_point_embeddings, + hidden_act=self.hidden_act, + ) + + def prepare_config_and_inputs(self): + dummy_points = floats_tensor([self.batch_size, 3, 2]) + config = self.get_config() + + return config, dummy_points + + +class SamHQMaskDecoderTester: + def __init__( + self, + hidden_size=32, + hidden_act="relu", + mlp_dim=64, + num_hidden_layers=12, + num_attention_heads=4, + attention_downsample_rate=2, + num_multimask_outputs=3, + iou_head_depth=3, + iou_head_hidden_dim=32, + layer_norm_eps=1e-6, + vit_dim=36, + ): + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_dim = mlp_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.attention_downsample_rate = attention_downsample_rate + self.num_multimask_outputs = num_multimask_outputs + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + self.layer_norm_eps = layer_norm_eps + self.vit_dim = vit_dim + + def get_config(self): + return SamHQMaskDecoderConfig( + hidden_size=self.hidden_size, + hidden_act=self.hidden_act, + mlp_dim=self.mlp_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + attention_downsample_rate=self.attention_downsample_rate, + num_multimask_outputs=self.num_multimask_outputs, + iou_head_depth=self.iou_head_depth, + iou_head_hidden_dim=self.iou_head_hidden_dim, + layer_norm_eps=self.layer_norm_eps, + vit_dim=self.vit_dim, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + + dummy_inputs = { + "image_embedding": floats_tensor([self.batch_size, self.hidden_size]), + } + return config, dummy_inputs + + +class SamHQModelTester: + def __init__( + self, + parent, + hidden_size=36, + intermediate_size=72, + projection_dim=62, + output_channels=32, + num_hidden_layers=12, + num_attention_heads=4, + num_channels=3, + image_size=24, + patch_size=2, + hidden_act="gelu", + layer_norm_eps=1e-06, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + qkv_bias=True, + mlp_ratio=4.0, + use_abs_pos=True, + use_rel_pos=True, + rel_pos_zero_init=False, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + num_pos_feats=16, + mlp_dim=None, + batch_size=2, + ): + self.parent = parent + self.image_size = image_size + self.patch_size = patch_size + self.output_channels = output_channels + self.num_channels = num_channels + self.hidden_size = hidden_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.dropout = dropout + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.rel_pos_zero_init = rel_pos_zero_init + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.num_pos_feats = num_pos_feats + self.mlp_dim = mlp_dim + self.batch_size = batch_size + + # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + + self.prompt_encoder_tester = SamHQPromptEncoderTester() + self.mask_decoder_tester = SamHQMaskDecoderTester() + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + config = self.get_config() + + return config, pixel_values + + def get_config(self): + vision_config = SamHQVisionConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + projection_dim=self.projection_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + initializer_range=self.initializer_range, + initializer_factor=self.initializer_factor, + output_channels=self.output_channels, + qkv_bias=self.qkv_bias, + mlp_ratio=self.mlp_ratio, + use_abs_pos=self.use_abs_pos, + use_rel_pos=self.use_rel_pos, + rel_pos_zero_init=self.rel_pos_zero_init, + window_size=self.window_size, + global_attn_indexes=self.global_attn_indexes, + num_pos_feats=self.num_pos_feats, + mlp_dim=self.mlp_dim, + ) + + prompt_encoder_config = self.prompt_encoder_tester.get_config() + + mask_decoder_config = self.mask_decoder_tester.get_config() + + return SamHQConfig( + vision_config=vision_config, + prompt_encoder_config=prompt_encoder_config, + mask_decoder_config=mask_decoder_config, + ) + + def create_and_check_model(self, config, pixel_values): + model = SamHQModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + # Explicitly pass multimask_output=True + result = model(pixel_values, multimask_output=True) + + self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3)) + self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3)) + + def create_and_check_get_image_features(self, config, pixel_values): + model = SamHQModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + image_embeddings = model.get_image_embeddings(pixel_values) + self.parent.assertEqual(image_embeddings[0][0].shape, (self.output_channels, 12, 12)) + + def create_and_check_get_image_and_intermediate_embeddings(self, config, pixel_values): + model = SamHQModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + image_embeddings, intermediate_embeddings = model.get_image_embeddings(pixel_values) + + self.parent.assertEqual(image_embeddings[0].shape, (self.output_channels, 12, 12)) + self.parent.assertEqual(intermediate_embeddings[0][0].shape, (12, 12, self.hidden_size)) + + def create_and_check_get_image_intermediate_embeddings(self, config, pixel_values): + model = SamHQModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + image_embeddings, intermediate_embeddings = model.get_image_embeddings(pixel_values) + + self.parent.assertIsInstance(intermediate_embeddings, list) + self.parent.assertTrue(len(intermediate_embeddings) > 0) + for embedding in intermediate_embeddings: + self.parent.assertEqual(embedding.shape, (self.batch_size, 12, 12, self.hidden_size)) + + def create_and_check_get_image_hidden_states(self, config, pixel_values): + model = SamHQModel(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model.vision_encoder( + pixel_values, + output_hidden_states=True, + return_dict=True, + ) + + # after computing the convolutional features + expected_hidden_states_shape = (self.batch_size, 12, 12, 36) + self.parent.assertEqual(len(result.hidden_states), self.num_hidden_layers + 1) + self.parent.assertEqual(result[-1][0].shape, expected_hidden_states_shape) + + with torch.no_grad(): + result = model.vision_encoder( + pixel_values, + output_hidden_states=True, + ) + + # after computing the convolutional features + expected_hidden_states_shape = (self.batch_size, 12, 12, 36) + self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1) + self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class SamHQModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as SAM-HQ's vision encoder does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (SamHQModel,) if is_torch_available() else () + pipeline_model_mapping = ( + {"feature-extraction": SamHQModel, "mask-generation": SamHQModel} if is_torch_available() else {} + ) + fx_compatible = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + test_torchscript = False + test_cpu_offload = False + test_disk_offload_bin = False + test_disk_offload_safetensors = False + + # TODO: Fix me @Arthur: `run_batch_test` in `tests/test_pipeline_mixin.py` not working + def is_pipeline_test_to_skip( + self, + pipeline_test_case_name, + config_class, + model_architecture, + tokenizer_name, + image_processor_name, + feature_extractor_name, + processor_name, + ): + return True + + def setUp(self): + self.model_tester = SamHQModelTester(self) + common_properties = ["initializer_range"] + self.config_tester = ConfigTester( + self, config_class=SamHQConfig, has_text_modality=False, common_properties=common_properties + ) + + def test_config(self): + self.config_tester.run_common_tests() + + @unittest.skip(reason="SAM-HQ's vision encoder does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Compile not yet supported in SamHQ models") + def test_sdpa_can_dispatch_on_flash(self): + pass + + def test_model_get_set_embeddings(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_get_image_features(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_get_image_features(*config_and_inputs) + + def test_get_image_and_intermediate_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_get_image_and_intermediate_embeddings(*config_and_inputs) + + def test_get_image_intermediate_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_get_image_intermediate_embeddings(*config_and_inputs) + + def test_image_hidden_states(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + expected_vision_attention_shape = ( + self.model_tester.batch_size * self.model_tester.num_attention_heads, + 196, + 196, + ) + expected_mask_decoder_attention_shape = (self.model_tester.batch_size, 1, 144, 32) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + vision_attentions = outputs.vision_attentions + self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers) + + mask_decoder_attentions = outputs.mask_decoder_attentions + self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + vision_attentions = outputs.vision_attentions + self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers) + + mask_decoder_attentions = outputs.mask_decoder_attentions + self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers) + + self.assertListEqual( + list(vision_attentions[0].shape[-4:]), + list(expected_vision_attention_shape), + ) + + self.assertListEqual( + list(mask_decoder_attentions[0].shape[-4:]), + list(expected_mask_decoder_attention_shape), + ) + + @unittest.skip(reason="SamHQModel does not support training") + def test_training(self): + pass + + @unittest.skip(reason="SamHQModel does not support training") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="SamHQModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="SamHQModel has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + @unittest.skip(reason="SamHQModel does not support training") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="Hidden_states is tested in create_and_check_model tests") + def test_hidden_states_output(self): + pass + + def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None): + # Use a slightly higher default tol to make the tests non-flaky + super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=tol, name=name, attributes=attributes) + + @slow + def test_model_from_pretrained(self): + model_name = "sushmanth/sam_hq_vit_b" + model = SamHQModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + @require_torch_sdpa + def test_sdpa_can_compile_dynamic(self): + self.skipTest(reason="SamHQModel can't be compiled dynamic yet") + + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + """ + Tests if composite models dispatch correctly on SDPA/eager when requested so when loading the model. + This tests only by looking at layer names, as usually SDPA layers are calles "SDPAAttention". + In contrast to the above test, this one checks if the "config._attn_implamentation" is a dict after the model + is loaded, because we manually replicate requested attn implementation on each sub-config when loading. + See https://github.com/huggingface/transformers/pull/32238 for more info + + The test tries to cover most general cases of composite models, VLMs with vision and text configs. Any model + that has a different set of sub-configs has to overwrite this test. + """ + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa") + model_sdpa = model_sdpa.eval().to(torch_device) + + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + + # Root model determines SDPA support + attn_impl = "sdpa" if model._supports_sdpa else "eager" + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + self.assertTrue(model_sdpa.vision_encoder.config._attn_implementation == attn_impl) + self.assertTrue(model_sdpa.mask_decoder.config._attn_implementation == attn_impl) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + self.assertTrue(model_eager.vision_encoder.config._attn_implementation == "eager") + self.assertTrue(model_eager.mask_decoder.config._attn_implementation == "eager") + + # Verify SDPA/eager layer presence + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + + if not has_sdpa and attn_impl == "sdpa": + raise ValueError("The SDPA model should have SDPA attention layers") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + +def prepare_image(): + img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +def prepare_dog_img(): + img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +@slow +class SamHQModelIntegrationTest(unittest.TestCase): + def tearDown(self): + super().tearDown() + # clean-up as much as possible GPU memory occupied by PyTorch + cleanup(torch_device, gc_collect=True) + + def test_inference_mask_generation_no_point(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + inputs = processor(images=raw_image, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores + + masks = outputs.pred_masks[0, 0, 0, 0, :3] + self.assertTrue(torch.allclose(scores[0][0][-1], torch.tensor(0.4482), atol=2e-4)) + self.assertTrue( + torch.allclose(masks, torch.tensor([-13.1695, -14.6201, -14.8989]).to(torch_device), atol=2e-3) + ) + + def test_inference_mask_generation_one_point_one_bb(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_boxes = [[[650, 900, 1000, 1250]]] + input_points = [[[820, 1080]]] + + inputs = processor( + images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + masks = outputs.pred_masks[0, 0, 0, 0, :3] + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9700), atol=2e-4)) + self.assertTrue( + torch.allclose(masks, torch.tensor([-29.9144, -30.0546, -30.9526]).to(torch_device), atol=3e-2) + ) + + def test_inference_mask_generation_batched_points_batched_images(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_points = [ + [[[820, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]], + [[[510, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]], + ] + + inputs = processor(images=[raw_image, raw_image], input_points=input_points, return_tensors="pt").to( + torch_device + ) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze().cpu() + masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu() + EXPECTED_SCORES = torch.tensor( + [ + [ + [0.9195, 0.8316, 0.6614], + [0.9195, 0.8316, 0.6614], + [0.9195, 0.8316, 0.6614], + [0.9195, 0.8316, 0.6614], + ], + [ + [0.7598, 0.7388, 0.3110], + [0.9195, 0.8317, 0.6614], + [0.9195, 0.8317, 0.6614], + [0.9195, 0.8317, 0.6614], + ], + ] + ) + EXPECTED_MASKS = torch.tensor([-40.2445, -37.4300, -38.1577]) + + self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3)) + self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=9e-3)) + + def test_inference_mask_generation_one_point_one_bb_zero(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_boxes = [[[620, 900, 1000, 1255]]] + input_points = [[[820, 1080]]] + labels = [[0]] + + inputs = processor( + images=raw_image, + input_boxes=input_boxes, + input_points=input_points, + input_labels=labels, + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.8680), atol=1e-3)) + + def test_inference_mask_generation_with_labels(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_points = [[[400, 650]]] + input_labels = [[1]] + + inputs = processor( + images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9137), atol=1e-4)) + + def test_inference_mask_generation_without_labels(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_points = [[[400, 650]]] + + inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9137), atol=1e-3)) + + def test_inference_mask_generation_two_points_with_labels(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_points = [[[400, 650], [800, 650]]] + input_labels = [[1, 1]] + + inputs = processor( + images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.8859), atol=1e-3)) + + def test_inference_mask_generation_two_points_without_labels(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_points = [[[400, 650], [800, 650]]] + + inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.8859), atol=1e-3)) + + def test_inference_mask_generation_two_points_batched(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = [[[400, 650], [800, 650]], [[400, 650]]] + input_labels = [[1, 1], [1]] + + inputs = processor( + images=[raw_image, raw_image], + input_points=input_points, + input_labels=input_labels, + images_kwargs={"point_pad_value": -10}, + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[0][-1], torch.tensor(0.4482), atol=1e-4)) + self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.4482), atol=1e-4)) + + def test_inference_mask_generation_one_box(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_boxes = [[[75, 275, 1725, 850]]] + + inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.6265), atol=1e-4)) + + def test_inference_mask_generation_batched_image_one_point(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + raw_dog_image = prepare_dog_img() + + input_points = [[[820, 1080]], [[220, 470]]] + + inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="pt").to( + torch_device + ) + + with torch.no_grad(): + outputs = model(**inputs) + scores_batched = outputs.iou_scores.squeeze() + + input_points = [[[220, 470]]] + + inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores_single = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4)) + + def test_inference_mask_generation_two_points_point_batch(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = torch.Tensor([[[400, 650]], [[220, 470]]]).cpu() # fmt: skip + + input_points = input_points.unsqueeze(0) + + inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + iou_scores = outputs.iou_scores.cpu() + self.assertTrue(iou_scores.shape == (1, 2, 3)) + torch.testing.assert_close( + iou_scores, torch.tensor([[[0.9889, 0.9508, 0.9137], [0.8070, 0.7934, 0.7932]]]), atol=1e-3, rtol=1e-3 + ) + + def test_inference_mask_generation_three_boxes_point_batch(self): + model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b") + processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + # fmt: off + input_boxes = torch.Tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]]).cpu() + EXPECTED_IOU = torch.tensor([[[0.9850, 0.9730, 0.9726], + [0.8891, 0.8017, 0.6265], + [0.8891, 0.8017, 0.6265]]]) + # fmt: on + input_boxes = input_boxes.unsqueeze(0) + + inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + iou_scores = outputs.iou_scores.cpu() + self.assertTrue(iou_scores.shape == (1, 3, 3)) + torch.testing.assert_close(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4) + + def test_dummy_pipeline_generation(self): + generator = pipeline("mask-generation", model="sushmanth/sam_hq_vit_b", device=torch_device) + raw_image = prepare_image() + + _ = generator(raw_image, points_per_batch=64) diff --git a/tests/models/sam_hq/test_processor_samhq.py b/tests/models/sam_hq/test_processor_samhq.py new file mode 100644 index 0000000000..9f24cc7ab4 --- /dev/null +++ b/tests/models/sam_hq/test_processor_samhq.py @@ -0,0 +1,167 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import shutil +import tempfile +import unittest + +import numpy as np + +from transformers.testing_utils import require_torch, require_torchvision, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_processing_common import ProcessorTesterMixin, prepare_image_inputs + + +if is_vision_available(): + from PIL import Image + + from transformers import AutoProcessor, SamHQProcessor, SamImageProcessor + +if is_torch_available(): + import torch + + +@require_vision +@require_torchvision +class SamHQProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = SamHQProcessor + + @classmethod + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = SamImageProcessor() + processor = SamHQProcessor(image_processor) + processor.save_pretrained(self.tmpdirname) + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + @classmethod + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + # Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. only contains an image processor + def prepare_image_inputs(self): + """This function prepares a list of PIL images.""" + return prepare_image_inputs() + + def prepare_mask_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. + """ + mask_inputs = [np.random.randint(255, size=(30, 400), dtype=np.uint8)] + mask_inputs = [Image.fromarray(x) for x in mask_inputs] + return mask_inputs + + def test_tokenizer_defaults_preserved_by_kwargs(self): + self.skipTest("SamHQProcessor does not have a tokenizer") + + def test_image_processor_defaults_preserved_by_image_kwargs(self): + self.skipTest("SamHQProcessor does not have a tokenizer") + + def test_chat_template_save_loading(self): + self.skipTest("SamHQProcessor does not have a tokenizer") + + def test_kwargs_overrides_default_image_processor_kwargs(self): + self.skipTest("SamHQProcessor does not have a tokenizer") + + def test_kwargs_overrides_default_tokenizer_kwargs(self): + self.skipTest("SamHQProcessor does not have a tokenizer") + + def test_unstructured_kwargs(self): + self.skipTest("SamHQProcessor does not have a tokenizer") + + def test_unstructured_kwargs_batched(self): + self.skipTest("SamHQProcessor does not have a tokenizer") + + def test_doubly_passed_kwargs(self): + self.skipTest("SamHQProcessor does not have a tokenizer") + + def test_structured_kwargs_nested(self): + self.skipTest("SamHQProcessor does not have a tokenizer") + + def test_structured_kwargs_nested_from_dict(self): + self.skipTest("SamHQProcessor does not have a tokenizer") + + def test_save_load_pretrained_additional_features(self): + self.skipTest("SamHQProcessor does not have a tokenizer") + + def test_image_processor_no_masks(self): + image_processor = self.get_image_processor() + + processor = SamHQProcessor(image_processor=image_processor) + + image_input = self.prepare_image_inputs() + + input_feat_extract = image_processor(image_input, return_tensors="pt") + input_processor = processor(images=image_input, return_tensors="pt") + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum().item(), input_processor[key].sum().item(), delta=1e-2) + + for image in input_feat_extract.pixel_values: + self.assertEqual(image.shape, (3, 1024, 1024)) + + for original_size in input_feat_extract.original_sizes: + np.testing.assert_array_equal(original_size, np.array([30, 400])) + + for reshaped_input_size in input_feat_extract.reshaped_input_sizes: + np.testing.assert_array_equal( + reshaped_input_size, np.array([77, 1024]) + ) # reshaped_input_size value is before padding + + def test_image_processor_with_masks(self): + image_processor = self.get_image_processor() + + processor = SamHQProcessor(image_processor=image_processor) + + image_input = self.prepare_image_inputs() + mask_input = self.prepare_mask_inputs() + + input_feat_extract = image_processor(images=image_input, segmentation_maps=mask_input, return_tensors="pt") + input_processor = processor(images=image_input, segmentation_maps=mask_input, return_tensors="pt") + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum().item(), input_processor[key].sum().item(), delta=1e-2) + + for label in input_feat_extract.labels: + self.assertEqual(label.shape, (256, 256)) + + @require_torch + def test_post_process_masks(self): + image_processor = self.get_image_processor() + + processor = SamHQProcessor(image_processor=image_processor) + dummy_masks = [torch.ones((1, 3, 5, 5))] + + original_sizes = [[1764, 2646]] + + reshaped_input_size = [[683, 1024]] + masks = processor.post_process_masks(dummy_masks, original_sizes, reshaped_input_size) + self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) + + masks = processor.post_process_masks( + dummy_masks, torch.tensor(original_sizes), torch.tensor(reshaped_input_size) + ) + self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) + + # should also work with np + dummy_masks = [np.ones((1, 3, 5, 5))] + masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size)) + + self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) + + dummy_masks = [[1, 0], [0, 1]] + with self.assertRaises(ValueError): + masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size)) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 9865067bcd..d6cb3f42c1 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -105,6 +105,8 @@ SPECIAL_CASES_TO_ALLOW = { "AutoformerConfig": ["num_static_real_features", "num_time_features"], # used internally to calculate `mlp_dim` "SamVisionConfig": ["mlp_ratio"], + # used internally to calculate `mlp_dim` + "SamHQVisionConfig": ["mlp_ratio"], # For (head) training, but so far not implemented "ClapAudioConfig": ["num_classes"], # Not used, but providing useful information to users diff --git a/utils/check_copies.py b/utils/check_copies.py index 9b392f1367..e37894c4ea 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -1040,6 +1040,7 @@ SPECIAL_MODEL_NAMES = { "OpenAI GPT": "GPT", "Perceiver": "Perceiver IO", "SAM": "Segment Anything", + "SAM_HQ": "Segment Anything High Quality", "ViT": "Vision Transformer (ViT)", } diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 09f21e0e2f..154eda0cf1 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -469,6 +469,8 @@ OBJECTS_TO_IGNORE = [ "SEWForCTC", "SamConfig", "SamPromptEncoderConfig", + "SamHQConfig", + "SamHQPromptEncoderConfig", "SeamlessM4TConfig", # use of unconventional markdown "SeamlessM4Tv2Config", # use of unconventional markdown "Seq2SeqTrainingArguments", diff --git a/utils/check_repo.py b/utils/check_repo.py index 4681e5b45c..1164ac9db0 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -235,6 +235,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ "JukeboxVQVAE", "JukeboxPrior", "SamModel", + "SamHQModel", "DPTForDepthEstimation", "DecisionTransformerGPT2Model", "GLPNForDepthEstimation", diff --git a/utils/not_doctested.txt b/utils/not_doctested.txt index e246a1bc4d..4dfcaca525 100644 --- a/utils/not_doctested.txt +++ b/utils/not_doctested.txt @@ -209,6 +209,7 @@ docs/source/en/model_doc/roc_bert.md docs/source/en/model_doc/roformer.md docs/source/en/model_doc/rwkv.md docs/source/en/model_doc/sam.md +docs/source/en/model_doc/sam_hq.md docs/source/en/model_doc/segformer.md docs/source/en/model_doc/sew-d.md docs/source/en/model_doc/sew.md