From 487dab1b2b2d0c4e3ca32ec713766cc25673006e Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Thu, 20 Mar 2025 10:14:38 -0400 Subject: [PATCH] Shieldgemma2 (#36678) * single commit * correct config * fixup * dummy pt * Use ShieldGemma2Config in conversion script * Update src/transformers/models/shieldgemma2/configuration_shieldgemma2.py * Adding shieldgemma2 to models.__init__.py * Adding ShieldGemma2 to main __init__.py * Update shieldgemma2.md * Update shieldgemma2.md * Adding tests. Addressing review feedback. * Minor docs update * Fixing code quality feedback from CI * Fixing empty messages bug reported by ghunkins --------- Co-authored-by: Arthur Zucker Co-authored-by: Ren Pang --- docs/source/en/model_doc/shieldgemma2.md | 100 ++++ src/transformers/__init__.py | 12 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + .../models/auto/image_processing_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 2 + .../models/auto/processing_auto.py | 1 + .../models/auto/tokenization_auto.py | 7 + .../models/shieldgemma2/__init__.py | 28 ++ .../configuration_shieldgemma2.py | 115 +++++ ...onvert_shieldgemma2_weights_orbax_to_hf.py | 470 ++++++++++++++++++ .../shieldgemma2/modeling_shieldgemma2.py | 228 +++++++++ .../shieldgemma2/processing_shieldgemma2.py | 195 ++++++++ src/transformers/utils/dummy_pt_objects.py | 7 + tests/models/shieldgemma2/__init__.py | 0 .../test_modeling_shieldgemma2.py | 61 +++ .../test_processing_shieldgemma2.py | 220 ++++++++ utils/check_config_attributes.py | 8 + utils/check_repo.py | 1 + 19 files changed, 1459 insertions(+) create mode 100644 docs/source/en/model_doc/shieldgemma2.md create mode 100644 src/transformers/models/shieldgemma2/__init__.py create mode 100644 src/transformers/models/shieldgemma2/configuration_shieldgemma2.py create mode 100644 src/transformers/models/shieldgemma2/convert_shieldgemma2_weights_orbax_to_hf.py create mode 100644 src/transformers/models/shieldgemma2/modeling_shieldgemma2.py create mode 100644 src/transformers/models/shieldgemma2/processing_shieldgemma2.py create mode 100644 tests/models/shieldgemma2/__init__.py create mode 100644 tests/models/shieldgemma2/test_modeling_shieldgemma2.py create mode 100644 tests/models/shieldgemma2/test_processing_shieldgemma2.py diff --git a/docs/source/en/model_doc/shieldgemma2.md b/docs/source/en/model_doc/shieldgemma2.md new file mode 100644 index 0000000000..016fe9bf98 --- /dev/null +++ b/docs/source/en/model_doc/shieldgemma2.md @@ -0,0 +1,100 @@ + + + +# ShieldGemma 2 + +## Overview + +The ShieldGemma 2 model was proposed in a forthcoming technical report by Google. ShieldGemma 2 is built on [Gemma 3](https://ai.google.dev/gemma/docs/core/model_card_3), is a 4 billion (4B) parameter model that checks the safety of both synthetic and natural images against key categories to help you build robust datasets and models. With this addition to the Gemma family of models, researchers and developers can now easily minimize the risk of harmful content in their models across key areas of harm as defined below: + +- No Sexually Explicit content: The image shall not contain content that depicts explicit or graphic sexual acts (e.g., pornography, erotic nudity, depictions of rape or sexual assault). +- No Dangerous Content: The image shall not contain content that facilitates or encourages activities that could cause real-world harm (e.g., building firearms and explosive devices, promotion of terrorism, instructions for suicide). +- No Violence/Gore content: The image shall not contain content that depicts shocking, sensational, or gratuitous violence (e.g., excessive blood and gore, gratuitous violence against animals, extreme injury or moment of death). + +We recommend using ShieldGemma 2 as an input filter to vision language models, or as an output filter of image generation systems. To train a robust image safety model, we curated training datasets of natural and synthetic images and instruction-tuned Gemma 3 to demonstrate strong performance. + +This model was contributed by [Ryan Mullins](https://huggingface.co/RyanMullins). + +## Usage Example + +- ShieldGemma 2 provides a Processor that accepts a list of `images` and an optional list of `policies` as input, and constructs a batch of prompts as the product of these two lists using the provided chat template. +- You can extend ShieldGemma's built-in in policies with the `custom_policies` argument to the Processor. Using the same key as one of the built-in policies will overwrite that policy with your custom defintion. +- ShieldGemma 2 does not support the image cropping capabilities used by Gemma 3. + +### Classification against Built-in Policies + +```python +from PIL import Image +import requests +from transformers import AutoProcessor, ShieldGemma2ForImageClassification + +model_id = "google/shieldgemma-2-4b-it" +model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto") +processor = AutoProcessor.from_pretrained(model_id) + +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg" +image = Image.open(requests.get(url, stream=True).raw) + +inputs = processor(images=[image], return_tensors="pt").to(model.device) + +output = model(**inputs) +print(output.probabilities) +``` + +### Classification against Custom Policies + +```python +from PIL import Image +import requests +from transformers import AutoProcessor, ShieldGemma2ForImageClassification + +model_id = "google/shieldgemma-2-4b-it" +model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto") +processor = AutoProcessor.from_pretrained(model_id) + +url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg" +image = Image.open(requests.get(url, stream=True).raw) + +custom_policies = { + "key_a": "descrition_a", + "key_b": "descrition_b", +} + +inputs = processor( + images=[image], + custom_policies=custom_policies, + policies=["dangerous", "key_a", "key_b"], + return_tensors="pt", +).to(model.device) + +output = model(**inputs) +print(output.probabilities) +``` + + +## ShieldGemma2Processor + +[[autodoc]] ShieldGemma2Processor + +## ShieldGemma2Config + +[[autodoc]] ShieldGemma2Config + +## ShieldGemma2ForImageClassification + +[[autodoc]] ShieldGemma2ForImageClassification + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 1e33af79ab..56960e9b5f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -774,6 +774,10 @@ _import_structure = { "models.seggpt": ["SegGptConfig"], "models.sew": ["SEWConfig"], "models.sew_d": ["SEWDConfig"], + "models.shieldgemma2": [ + "ShieldGemma2Config", + "ShieldGemma2Processor", + ], "models.siglip": [ "SiglipConfig", "SiglipProcessor", @@ -3581,6 +3585,7 @@ else: "SEWDPreTrainedModel", ] ) + _import_structure["models.shieldgemma2"].append("ShieldGemma2ForImageClassification") _import_structure["models.siglip"].extend( [ "SiglipForImageClassification", @@ -5982,6 +5987,10 @@ if TYPE_CHECKING: from .models.seggpt import SegGptConfig from .models.sew import SEWConfig from .models.sew_d import SEWDConfig + from .models.shieldgemma2 import ( + ShieldGemma2Config, + ShieldGemma2Processor, + ) from .models.siglip import ( SiglipConfig, SiglipProcessor, @@ -8350,6 +8359,9 @@ if TYPE_CHECKING: SEWDModel, SEWDPreTrainedModel, ) + from .models.shieldgemma2 import ( + ShieldGemma2ForImageClassification, + ) from .models.siglip import ( SiglipForImageClassification, SiglipModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 3a72fca91e..22a0d281b3 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -247,6 +247,7 @@ from . import ( seggpt, sew, sew_d, + shieldgemma2, siglip, siglip2, smolvlm, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 0969976937..3b7445f604 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -274,6 +274,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("seggpt", "SegGptConfig"), ("sew", "SEWConfig"), ("sew-d", "SEWDConfig"), + ("shieldgemma2", "ShieldGemma2Config"), ("siglip", "SiglipConfig"), ("siglip2", "Siglip2Config"), ("siglip_vision_model", "SiglipVisionConfig"), @@ -625,6 +626,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("seggpt", "SegGPT"), ("sew", "SEW"), ("sew-d", "SEW-D"), + ("shieldgemma2", "Shieldgemma2"), ("siglip", "SigLIP"), ("siglip2", "SigLIP2"), ("siglip2_vision_model", "Siglip2VisionModel"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 4dfa2e8645..9a5edd4835 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -137,6 +137,7 @@ else: ("sam", ("SamImageProcessor",)), ("segformer", ("SegformerImageProcessor",)), ("seggpt", ("SegGptImageProcessor",)), + ("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")), ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")), ("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")), ("superglue", "SuperGlueImageProcessor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index f08fc2fa04..649b56a214 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -727,6 +727,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("regnet", "RegNetForImageClassification"), ("resnet", "ResNetForImageClassification"), ("segformer", "SegformerForImageClassification"), + ("shieldgemma2", "ShieldGemma2ForImageClassification"), ("siglip", "SiglipForImageClassification"), ("siglip2", "Siglip2ForImageClassification"), ("swiftformer", "SwiftFormerForImageClassification"), @@ -849,6 +850,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict( ("pixtral", "LlavaForConditionalGeneration"), ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), ("qwen2_vl", "Qwen2VLForConditionalGeneration"), + ("shieldgemma2", "Gemma3ForConditionalGeneration"), ("smolvlm", "SmolVLMForConditionalGeneration"), ("udop", "UdopForConditionalGeneration"), ("vipllava", "VipLlavaForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index a318d443fb..5b699a4a44 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -101,6 +101,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( ("seamless_m4t", "SeamlessM4TProcessor"), ("sew", "Wav2Vec2Processor"), ("sew-d", "Wav2Vec2Processor"), + ("shieldgemma2", "ShieldGemma2Processor"), ("siglip", "SiglipProcessor"), ("siglip2", "Siglip2Processor"), ("speech_to_text", "Speech2TextProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 5c198fe4af..13f9a8a429 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -493,6 +493,13 @@ else: "SeamlessM4TTokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "shieldgemma2", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), ("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)), ( "siglip2", diff --git a/src/transformers/models/shieldgemma2/__init__.py b/src/transformers/models/shieldgemma2/__init__.py new file mode 100644 index 0000000000..3eaa894027 --- /dev/null +++ b/src/transformers/models/shieldgemma2/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_shieldgemma2 import * + from .modeling_shieldgemma2 import * + from .processing_shieldgemma2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/shieldgemma2/configuration_shieldgemma2.py b/src/transformers/models/shieldgemma2/configuration_shieldgemma2.py new file mode 100644 index 0000000000..8094cb14b4 --- /dev/null +++ b/src/transformers/models/shieldgemma2/configuration_shieldgemma2.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class ShieldGemma2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ShieldGemma2ForImageClassification`]. It is used to instantiate an + ShieldGemma2ForImageClassification according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the shieldgemma-2-4b-it. + + e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + text_config (`Union[ShieldGemma2TextConfig, dict]`, *optional*): + The config object of the text backbone. + vision_config (`Union[AutoConfig, dict]`, *optional*): + Custom vision config or dict. + mm_tokens_per_image (`int`, *optional*, defaults to 256): + The number of tokens per image embedding. + boi_token_index (`int`, *optional*, defaults to 255999): + The begin-of-image token index to wrap the image prompt. + eoi_token_index (`int`, *optional*, defaults to 256000): + The end-of-image token index to wrap the image prompt. + image_token_index (`int`, *optional*, defaults to 262144): + The image token index to encode the image prompt. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + + Example: + + ```python + >>> from transformers import ShieldGemma2ForConditionalGeneration, ShieldGemma2Config, SiglipVisionConfig, ShieldGemma2TextConfig + + >>> # Initializing a Siglip-like vision config + >>> vision_config = SiglipVisionConfig() + + >>> # Initializing a ShieldGemma2 Text config + >>> text_config = ShieldGemma2TextConfig() + + >>> # Initializing a ShieldGemma2 gemma-3-4b style configuration + >>> configuration = ShieldGemma2Config(vision_config, text_config) + + >>> # Initializing a model from the gemma-3-4b style configuration + >>> model = ShieldGemma2TextConfig(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "shieldgemma2" + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} + + def __init__( + self, + text_config=None, + vision_config=None, + mm_tokens_per_image: int = 256, + boi_token_index: int = 255_999, + eoi_token_index: int = 256_000, + image_token_index: int = 262_144, + initializer_range: float = 0.02, + **kwargs, + ): + if isinstance(vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model" + ) + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["siglip_vision_model"]() + + self.vision_config = vision_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma3_text" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["gemma3_text"]() + + self.text_config = text_config + self.vision_config = vision_config + self.mm_tokens_per_image = mm_tokens_per_image + self.boi_token_index = boi_token_index + self.eoi_token_index = eoi_token_index + self.image_token_index = image_token_index + self.initializer_range = initializer_range + + super().__init__(**kwargs) + + +__all__ = ["ShieldGemma2Config"] diff --git a/src/transformers/models/shieldgemma2/convert_shieldgemma2_weights_orbax_to_hf.py b/src/transformers/models/shieldgemma2/convert_shieldgemma2_weights_orbax_to_hf.py new file mode 100644 index 0000000000..0ca46eb29e --- /dev/null +++ b/src/transformers/models/shieldgemma2/convert_shieldgemma2_weights_orbax_to_hf.py @@ -0,0 +1,470 @@ +r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint. + +python -m transformers.models.shieldgemma2.convert_shieldgemma2_weights_orbax_to_hf \ + --tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \ + --checkpoint_path_gemma="$HOME/gemma3/gemma3_4b_pt_orbax/" \ + --checkpoint_path_shieldgemma="$HOME/shieldgemma2/shieldgemma-2_4b_orbax/" \ + --output_path="$HOME/shieldgemma2/shieldgemma2_4b_pt_safetensors/" \ + --precision='bfloat16' +""" + +import dataclasses +from collections.abc import Iterator, Mapping, Sequence +from typing import Any + +import accelerate +import numpy as np +import torch +import tree +from absl import app, flags, logging +from orbax import checkpoint as obc + +from ...image_utils import PILImageResampling +from ..gemma import GemmaTokenizerFast +from ..gemma3 import Gemma3ImageProcessor, Gemma3TextConfig +from ..siglip import SiglipVisionConfig +from . import ShieldGemma2Config, ShieldGemma2ForImageClassification, ShieldGemma2Processor + + +# ==== Internal Constants and Classes ==== + +_CHAT_TEMPLATE = """{{ bos_token }} +{%- for message in messages -%} + {%- if message['content'] is iterable -%} + {%- for item in message['content'] -%} + {{ 'user\n' }} + {%- if item['type'] == 'image' -%} + {{ '' }} + {%- elif item['type'] == 'text' -%} + {{ 'You are a policy expert trying to help determine whether a generated or real image is in' }} + {{ ' violation of the safety policy below. Your decision must very clearly follow the specific'}} + {{ ' language used in the safety policy and should not include any interpretation outside of the' }} + {{ ' safety policy.\n\n' }} + {{ '\n' }} + {{ item['text'] + '\n' }} + {{ '\n\n' }} + {{ 'Does the image violate the above policy? Your Answer must start with \\'Yes\\' or \\'No\\'.' }} + {{ '\n' }} + {%- endif -%} + {%- endfor -%} + {{'model\n'}} + {%- else -%} + {{ raise_exception("Conversation messages must contain iterable content containing images and policy definitions in text.") }} + {%- endif -%} +{%- endfor -%} +""" + +_DTYPES = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, +} + +_SHIELDGEMMA2_POLICIES: Mapping[str, str] = { + "dangerous": ( + "No Dangerous Content: The image shall not contain content that" + " facilitates or encourages activities that could cause real-world harm" + " (e.g., building firearms and explosive devices, promotion of" + " terrorism, instructions for suicide)." + ), + "sexual": ( + "No Sexually Explicit content: The image shall not contain content that" + " depicts explicit or graphic sexual acts (e.g., pornography, erotic" + " nudity, depictions of rape or sexual assault)." + ), + "violence": ( + "No Violence/Gore content: The image shall not contain content that" + " depicts shocking, sensational, or gratuitous violence (e.g.," + " excessive blood and gore, gratuitous violence against animals," + " extreme injury or moment of death)." + ), +} + +_SIGLIP_BASE = "SigLiPFromPatches_0/siglip_encoder" +_SIGLIP_EMBEDDING = "SigLiPFromPatches_0/siglip_encoder/embedding" +_SIGLIP_TRANSFORMER_ENCODER_BLOCK = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_" +_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN = len(_SIGLIP_TRANSFORMER_ENCODER_BLOCK) +_SIGLIP_TRANSFORMER_ENCODER_NORM = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm" + +_TRANSFORMER_DECODER_BLOCK = "transformer/layer_" +_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK) +_TRANSFORMER_EMBEDDER = "transformer/embedder" +_TRANSFORMER_FINAL_NORM = "transformer/final_norm" +_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/" +_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX) + +# ==== Flags ==== + +_GEMMA_CHECKPOINT_PATH = flags.DEFINE_string( + name="checkpoint_path_gemma", + default=None, + help="Path to the Orbax checkpoint containing the vision weights.", + required=True, +) + +_SHIELDGEMMA_CHECKPOINT_PATH = flags.DEFINE_string( + name="checkpoint_path_shieldgemma", + default=None, + help="Path to the Orbax checkpoint containing the language model weights.", + required=True, +) + +OUTPUT_PATH = flags.DEFINE_string( + name="output_path", + default=None, + help="Path to store the HF checkpoint.", + required=True, +) + +PRECISION = flags.DEFINE_enum( + name="precision", + default=None, + help="The floating point precision (aka dtype) of the model.", + enum_values=set(_DTYPES.keys()), + required=True, +) + +TOKENIZER_PATH = flags.DEFINE_string( + name="tokenizer_path", + default=None, + help="Path to the SentencePiece model file.", + required=True, +) + + +def convert_siglip_weight( + config: SiglipVisionConfig, + paths: Sequence[str], + weights: np.ndarray, +) -> tuple[str, np.ndarray]: + path, prop = paths + normalized_path: str = "" + updated_weights: np.ndarray = None + + if path == _SIGLIP_BASE: + normalized_path = "vision_tower.vision_model.embeddings.position_embedding.weight" + updated_weights = weights.reshape(-1, config.hidden_size) + elif path == _SIGLIP_EMBEDDING: + if prop == "kernel": + normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.weight" + updated_weights = weights.transpose(3, 2, 0, 1) + elif prop == "bias": + normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.bias" + updated_weights = weights + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.") + elif path.startswith(_SIGLIP_TRANSFORMER_ENCODER_BLOCK): + encoder_block_path = path[_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN:] + next_path_seperator_idx = encoder_block_path.find("/") + layer_idx = encoder_block_path[:next_path_seperator_idx] + encoder_block_path = encoder_block_path[next_path_seperator_idx:] + normalized_path = f"vision_tower.vision_model.encoder.layers.{layer_idx}" + + if encoder_block_path.startswith("/LayerNorm"): + normalized_path += ".layer_norm1" if path.endswith("_0") else ".layer_norm2" + + if prop == "scale": + normalized_path += ".weight" + updated_weights = weights.transpose() + elif prop == "bias": + normalized_path += ".bias" + updated_weights = weights + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.") + elif encoder_block_path.startswith("/MlpBlock_0"): + normalized_path += ".mlp.fc1" if "/Dense_0" in encoder_block_path else ".mlp.fc2" + + if prop == "kernel": + normalized_path += ".weight" + updated_weights = weights.transpose() + elif prop == "bias": + normalized_path += ".bias" + updated_weights = weights + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.") + elif encoder_block_path.startswith("/MultiHeadDotProductAttention_0"): + if encoder_block_path.endswith("/key"): + normalized_path += ".self_attn.k_proj" + elif encoder_block_path.endswith("/out"): + normalized_path += ".self_attn.out_proj" + elif encoder_block_path.endswith("/query"): + normalized_path += ".self_attn.q_proj" + elif encoder_block_path.endswith("/value"): + normalized_path += ".self_attn.v_proj" + else: + raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer MultiHeadDotProductAttention_0.") + + if prop == "bias": + normalized_path += ".bias" + updated_weights = weights.reshape(-1, config.hidden_size).reshape(-1) + elif prop == "kernel": + normalized_path += ".weight" + updated_weights = weights.reshape(-1, config.hidden_size).transpose() + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.") + else: + raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer Encoder Block.") + elif path == _SIGLIP_TRANSFORMER_ENCODER_NORM: + if prop == "scale": + normalized_path = "vision_tower.vision_model.post_layernorm.weight" + updated_weights = weights.transpose() + elif prop == "bias": + normalized_path = "vision_tower.vision_model.post_layernorm.bias" + updated_weights = weights + else: + raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.") + else: + raise ValueError(f"Unexpected path `{path}`.") + + if "vision" in normalized_path: + print(normalized_path) + return normalized_path, updated_weights + + +def convert_transformer_weights( + config: Gemma3TextConfig, + paths: Sequence[str], + weights: np.ndarray, +) -> Iterator[tuple[str, np.ndarray]]: + path, prop = paths + + if path.startswith(_TRANSFORMER_POST_TRAINING_PREFIX): + path = path[_TRANSFORMER_POST_TRAINING_PREFIX_LEN:] + + converted_paths: list[str] = [] + converted_weights: list[Any] = [] + + attn_head_dim = config.num_attention_heads * config.head_dim + kv_head_dim = config.num_key_value_heads * config.head_dim + + if path == _TRANSFORMER_EMBEDDER: + if prop == "input_embedding": + # Tied to language_model.lm_head.weight, assigned at the end. + converted_paths = ["language_model.model.embed_tokens.weight"] + # Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama + pre_expansion_embeddings = weights + mu = np.mean(pre_expansion_embeddings, axis=0) + sigma = np.cov(pre_expansion_embeddings, rowvar=False, bias=True) + new_embeddings = np.random.multivariate_normal(mu, 1e-5 * sigma, size=64) + weights = np.vstack([pre_expansion_embeddings, new_embeddings]) + converted_weights = [weights] + else: + raise ValueError(f"Unexpected member, {prop}, in Embedder.") + elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"): + if path.endswith("/mm_input_projection"): + converted_paths = ["multi_modal_projector.mm_input_projection_weight"] + converted_weights = [weights] + elif path.endswith("/mm_soft_embedding_norm"): + converted_paths = ["multi_modal_projector.mm_soft_emb_norm.weight"] + converted_weights = [weights] + else: + raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.") + elif path == _TRANSFORMER_FINAL_NORM: + converted_paths = ["language_model.model.norm.weight"] + converted_weights = [weights] + elif path.startswith(_TRANSFORMER_DECODER_BLOCK): + decoder_block_path = path[_TRANSFORMER_DECODER_BLOCK_LEN:] + next_path_seperator_idx = decoder_block_path.find("/") + layer_idx = decoder_block_path[:next_path_seperator_idx] + decoder_block_path = decoder_block_path[next_path_seperator_idx:] + + base_path = f"language_model.model.layers.{layer_idx}" + + if path.endswith("attn/attn_vec_einsum"): + converted_paths = [f"{base_path}.self_attn.o_proj.weight"] + converted_weights = [weights.transpose(2, 0, 1).reshape(config.hidden_size, attn_head_dim)] + elif path.endswith("attn/_key_norm"): + converted_paths = [f"{base_path}.self_attn.k_norm.weight"] + converted_weights = [weights] + elif path.endswith("attn/kv_einsum"): + converted_paths = [ + f"{base_path}.self_attn.k_proj.weight", + f"{base_path}.self_attn.v_proj.weight", + ] + k_proj_weights, v_proj_weights = weights + converted_weights = [ + k_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size), + v_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size), + ] + elif path.endswith("attn/q_einsum"): + converted_paths = [f"{base_path}.self_attn.q_proj.weight"] + converted_weights = [weights.transpose(0, 2, 1).reshape(attn_head_dim, config.hidden_size)] + elif path.endswith("attn/_query_norm"): + converted_paths = [f"{base_path}.self_attn.q_norm.weight"] + converted_weights = [weights] + elif path.endswith("mlp/gating_einsum"): + converted_paths = [ + f"{base_path}.mlp.gate_proj.weight", + f"{base_path}.mlp.up_proj.weight", + ] + gate_proj_weight, up_proj_weight = weights + converted_weights = [gate_proj_weight, up_proj_weight] + elif path.endswith("mlp/linear"): + converted_paths = [f"{base_path}.mlp.down_proj.weight"] + converted_weights = [weights.transpose()] + elif path.endswith("post_attention_norm"): + converted_paths = [f"{base_path}.post_attention_layernorm.weight"] + converted_weights = [weights] + elif path.endswith("post_ffw_norm"): + converted_paths = [f"{base_path}.post_feedforward_layernorm.weight"] + converted_weights = [weights] + elif path.endswith("pre_attention_norm"): + converted_paths = [f"{base_path}.input_layernorm.weight"] + converted_weights = [weights] + elif path.endswith("pre_ffw_norm"): + converted_paths = [f"{base_path}.pre_feedforward_layernorm.weight"] + converted_weights = [weights] + else: + raise ValueError(f"Unexpected path `{path}` in Decoder Block.") + else: + raise ValueError(f"Unexpected path `{path}`.") + + if (cpl := len(converted_paths)) != (cwl := len(converted_weights)): + raise ValueError( + "The `converted_paths` and `converted_weights` should be the same " + f"length. Got {cpl} and {cwl}, respectively, for {path}." + ) + + return zip(converted_paths, converted_weights) + + +def transpose_reshape(x: torch.Tensor) -> torch.Tensor: + x = x.transpose(1, 2) + return x.reshape(x.shape[0] * x.shape[1], x.shape[2]).contiguous() + + +@dataclasses.dataclass(frozen=True) +class ConversionResult: + state_tree: dict[str, torch.Tensor] + config: ShieldGemma2Config + + +def convert( + shieldgemma_checkpoint_path: str, + gemma_checkpoint_path: str, + config: ShieldGemma2Config, + target_dtype: torch.dtype, +) -> ConversionResult: + """Loads Orbax checkpoint from `input_path` and converts it to HF tree.""" + checkpointer = obc.PyTreeCheckpointer() + + sg2_ckpt = checkpointer.restore(shieldgemma_checkpoint_path) + g3_ckpt = checkpointer.restore(gemma_checkpoint_path) + + hf_tree: dict[str, torch.Tensor] = {} + + def update_tree(path: str, weights: np.ndarray) -> None: + torch_tensor = torch.from_numpy(weights.astype("float32")).type(target_dtype) + logging.info( + "%s converted shape=%s with dtype=%s", + path, + weights.shape, + torch_tensor.dtype, + ) + hf_tree[f"model.{path}"] = torch_tensor + + for paths, value in tree.flatten_with_path(g3_ckpt): + if paths[0].startswith("SigLiPFromPatches_"): + path, weights = convert_siglip_weight(config=config.vision_config, paths=paths, weights=value) + update_tree(path, weights) + + for paths, value in tree.flatten_with_path(sg2_ckpt): + for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value): + update_tree(path, weights) + + hf_tree["model.language_model.lm_head.weight"] = hf_tree["model.language_model.model.embed_tokens.weight"] + + return ConversionResult(state_tree=hf_tree, config=config) + + +def main(*args): + del args + + dtype = getattr(torch, PRECISION.value) + output_path = OUTPUT_PATH.value + + tokenizer = GemmaTokenizerFast( + TOKENIZER_PATH.value, + extra_special_tokens={ + "image_token": "", # Should be ID=262_144 + "boi_token": "", # Should be ID=255_999 + "eoi_token": "", # Should be ID=256_000 + }, + ) + + yes_token_index, no_token_index = torch.tensor(tokenizer(["Yes", "No"])["input_ids"])[:, 1].numpy() + + config = ShieldGemma2Config( + yes_token_index=int(yes_token_index), + no_token_index=int(no_token_index), + text_config=Gemma3TextConfig( + vocab_size=262_208, + hidden_size=2560, + intermediate_size=2560 * 8 // 2, + num_attention_heads=8, + head_dim=256, + num_hidden_layers=34, + num_key_value_heads=4, + sliding_window=1024, + rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only + rope_theta=1_000_000, + rope_local_base_freq=10_000, + attn_logit_softcapping=None, + query_pre_attn_scalar=256, + max_position_embeddings=8192, + ), + vision_config={ + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + "num_channels": 3, + "image_size": 896, + "patch_size": 14, + "hidden_act": "gelu_pytorch_tanh", + "layer_norm_eps": 1e-6, + "attention_dropout": 0.0, + "vision_use_head": False, + }, + ) + + config.save_pretrained(output_path) + + image_processor = Gemma3ImageProcessor( + image_seq_length=256, + image_mean=(0.5,) * 3, + image_std=(0.5,) * 3, + size={"height": 896, "width": 896}, + resample=PILImageResampling.BILINEAR, + ) + processor = ShieldGemma2Processor( + image_processor=image_processor, + tokenizer=tokenizer, + policy_definitions=_SHIELDGEMMA2_POLICIES, + ) + tokenizer.chat_template = _CHAT_TEMPLATE + processor.chat_template = _CHAT_TEMPLATE + + processor.save_pretrained(output_path) + logging.info("Saved Shieldgemma2Processor to %s", output_path) + del processor + del tokenizer + + logging.info("Converting Shieldgemma2 @ %s", dtype) + result = convert(_SHIELDGEMMA_CHECKPOINT_PATH.value, _GEMMA_CHECKPOINT_PATH.value, config, dtype) + logging.info("Converted Shieldgemma2 state tree from Orbax to Hugging Face.") + + with accelerate.init_empty_weights(): + model = ShieldGemma2ForImageClassification(config=config) + + model.load_state_dict(result.state_tree, assign=True, strict=True) + model.config.torch_dtype = dtype + logging.info("Loaded Shieldgemma2 in Hugging Face Transformers.") + model.save_pretrained(output_path, safe_serialization=True) + logging.info("Saved Shieldgemma2 to SafeTensors in %s", output_path) + del model + del result + + +if __name__ == "__main__": + app.run(main) diff --git a/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py b/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py new file mode 100644 index 0000000000..47ebc3e1f6 --- /dev/null +++ b/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py @@ -0,0 +1,228 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import List, Optional, Union + +import torch +import torch.utils.checkpoint + +from ...cache_utils import Cache +from ...modeling_outputs import ImageClassifierOutputWithNoAttention +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.deprecation import deprecate_kwarg +from ..auto import AutoModelForImageTextToText +from .configuration_shieldgemma2 import ShieldGemma2Config + + +_CHECKPOINT_FOR_DOC = "google/shieldgemma-2-4b-it" +_CONFIG_FOR_DOC = "ShieldGemma2Config" + +logger = logging.get_logger(__name__) + +SHIELDGEMMA2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. + + Returns: + A `ShieldGemma2ImageClassifierOutputWithNoAttention` instance continaing the logits and probabilities + associated with the model predicting the `Yes` or `No` token as the response to that prompt, captured in the + following properties. + + * `logits` (`torch.Tensor` of shape `(batch_size, 2)`): + The first position along dim=1 is the logits for the `Yes` token and the second position along dim=1 is + the logits for the `No` token. + * `probabilities` (`torch.Tensor` of shape `(batch_size, 2)`): + The first position along dim=1 is the probability of predicting the `Yes` token and the second position + along dim=1 is the probability of predicting the `No` token. + + ShieldGemma prompts are constructed such that predicting the `Yes` token means the content *does violate* the + policy as described. If you are only interested in the violative condition, use + `violated = outputs.probabilities[:, 1]` to extract that slice from the output tensors. + + When used with the `ShieldGemma2Processor`, the `batch_size` will be equal to `len(images) * len(policies)`, + and the order within the batch will be img1_policy1, ... img1_policyN, ... imgM_policyN. +""" + + +@dataclass +class ShieldGemma2ImageClassifierOutputWithNoAttention(ImageClassifierOutputWithNoAttention): + """ShieldGemma2 classifies imags as violative or not relative to a specific policy + Args: + """ + + probabilities: torch.Tensor = None + + +class ShieldGemma2ForImageClassification(PreTrainedModel): + config_class = ShieldGemma2Config + + def __init__(self, config: ShieldGemma2Config): + super().__init__(config=config) + self.yes_token_index = getattr(config, "yes_token_index", 10_784) + self.no_token_index = getattr(config, "no_token_index", 3771) + self.model = AutoModelForImageTextToText.from_config(config=config) + + def get_input_embeddings(self): + return self.model.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.model.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.model.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.model.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.model.language_model.get_decoder() + + def tie_weights(self): + return self.model.language_model.tie_weights() + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(SHIELDGEMMA2_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=ShieldGemma2ImageClassifierOutputWithNoAttention, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + token_type_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, + ) -> ShieldGemma2ImageClassifierOutputWithNoAttention: + """Predicts the binary probability that the image violates the speicfied policy. + + Returns: + """ + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + token_type_ids=token_type_ids, + cache_position=cache_position, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + logits_to_keep=logits_to_keep, + **lm_kwargs, + ) + logits = outputs.logits + selected_logits = logits[:, -1, [self.yes_token_index, self.no_token_index]] + probabilities = torch.softmax(selected_logits, dim=-1) + return ShieldGemma2ImageClassifierOutputWithNoAttention( + logits=selected_logits, + probabilities=probabilities, + ) + + +__all__ = [ + "ShieldGemma2ForImageClassification", +] diff --git a/src/transformers/models/shieldgemma2/processing_shieldgemma2.py b/src/transformers/models/shieldgemma2/processing_shieldgemma2.py new file mode 100644 index 0000000000..097b83ef6e --- /dev/null +++ b/src/transformers/models/shieldgemma2/processing_shieldgemma2.py @@ -0,0 +1,195 @@ +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Mapping, Sequence +from typing import Optional + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import Unpack +from ...utils import logging +from ..gemma3.processing_gemma3 import Gemma3Processor, Gemma3ProcessorKwargs + + +logger = logging.get_logger(__name__) + +DEFAULT_SHIELDGEMMA2_POLICIES: Mapping[str, str] = { + "dangerous": ( + "No Dangerous Content: The image shall not contain content that" + " facilitates or encourages activities that could cause real-world harm" + " (e.g., building firearms and explosive devices, promotion of" + " terrorism, instructions for suicide)." + ), + "sexual": ( + "No Sexually Explicit content: The image shall not contain content that" + " depicts explicit or graphic sexual acts (e.g., pornography, erotic" + " nudity, depictions of rape or sexual assault)." + ), + "violence": ( + "No Violence/Gore content: The image shall not contain content that" + " depicts shocking, sensational, or gratuitous violence (e.g.," + " excessive blood and gore, gratuitous violence against animals," + " extreme injury or moment of death)." + ), +} + + +class ShieldGemma2ProcessorKwargs(Gemma3ProcessorKwargs, total=False): + policies: Optional[Sequence[str]] + custom_policies: Optional[Mapping[str, str]] + _defaults = { + "text_kwargs": { + "padding": True, + }, + "images_kwargs": { + "do_pan_and_scan": False, + }, + } + + +class ShieldGemma2Processor(Gemma3Processor): + def __init__( + self, image_processor, tokenizer, chat_template=None, image_seq_length=256, policy_definitions=None, **kwargs + ): + """A processor for the ShieldGemma 2 model. + + Args: + image_processor: The image processor to use, typically a `Gemma3ImageProcessorFast` instance. + tokenizer: The tokenizer to use, typically a `GemmaTokenizerFast` instance. + chat_template: The chat template to use with this processor. Typically, this is unset as the processor + configuration on Hugging Face Hub includes this value already. + image_seq_length: The number of soft tokens per image. Typically, this is unset as the processor + configuration on Hugging Face Hub includes this value already. + policy_definitions: A mapping from policy name to its description in text used as the default policies to + classify images against. The policy descriptions are included in the text of the prompts generated by + this processor. Typically, this is unset as the processor configuration on Hugging Face Hub includes + the base policies ShieldGemma was trained on. + """ + super().__init__(image_processor, tokenizer, chat_template, image_seq_length, **kwargs) + if policy_definitions is None: + self.policy_definitions = DEFAULT_SHIELDGEMMA2_POLICIES + else: + self.policy_definitions = policy_definitions + + def __call__( + self, + images: ImageInput = None, + text=None, + videos=None, + audio=None, + **kwargs: Unpack[ShieldGemma2ProcessorKwargs], + ) -> BatchFeature: + """Generates a batch of inputs from the provided images. + + ShieldGemma was trained to classify image content for policy compliance using a specific prompt construction. + This processor generates a batch of such prompts from the provided images by: + + 1. Creating a list of conversations, one for each `` pair; + 2. Converting these conversations to text using `self.apply_chat_template()`; and + 3. Encoding the conversations and images using the same techniques as `Gemma3Processor`. + + Args: + images: A single image or a list of images to include in the batch. + text: Not supported. + videos: Not supported. + audio: Not supported. + kwargs: An optional dictionary of keyword arguments to configre the + processor. Possible values include: + + * `custom_policies`: Additional policy definitions that augment the `self.policy_definitions` passed + into the constructor. Note that `custom_policies` that share a key with `self.policy_definitions` + will override the policy description + * `policies`: (Optional) a list of keys in the joint `self.policy_definitions | custom_policies` + dictionary of specific interest for the provided images. If empty or None, prompts will be + generated for every key in the joint dictionary. + + Returns: + A `BatchFeature` continaing `input_ids`, `pixel_values`, etc. where each Tensor is of shape + `(len(images) * len(policies), )`, and the order within the batch will be + img1_policy1, ... img1_policyN, ... imgM_policyN. + """ + del text, videos, audio + + if not images: + raise ValueError("ShieldGemma 2 needs images to classify") + elif not isinstance(images, Sequence): + images = [images] + + if not self.chat_template: + raise ValueError("ShieldGemma 2 requires the use of a specific chat template") + + # Disable pan and scan + images_kwargs = kwargs.setdefault("images_kwargs", {}) + if images_kwargs.get("do_pan_and_scan") is True: + logger.warning_once("ShieldGemma2 does not support pan and scan.") + images_kwargs["do_pan_and_scan"] = False + + # Enable padding on the batch during tokenization + text_kwargs = kwargs.setdefault("text_kwargs", {}) + if "padding" not in text_kwargs: + text_kwargs["padding"] = kwargs.pop("padding", True) + text_kwargs["padding_side"] = kwargs.pop("padding_side", "left") + + policy_definitions: Mapping[str, str] = { + **self.policy_definitions, + **kwargs.get("custom_policies", {}), + } + + if (policies := kwargs.get("policies")) is None: + policies = list(policy_definitions.keys()) + + # TODO(ryanmullins): Support images from PIL or URLs. + messages = [] + expanded_images = [] + for img in images: + for policy in policies: + messages.append( + [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": policy_definitions[policy]}, + ], + } + ] + ) + expanded_images.append([img]) + + text = self.apply_chat_template(messages, tokenize=False) + return super().__call__(images=expanded_images, text=text, **kwargs) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"] + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["ShieldGemma2Processor"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index ceca4158f5..daef535432 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -8870,6 +8870,13 @@ class SEWDPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["torch"]) +class ShieldGemma2ForImageClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class SiglipForImageClassification(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/shieldgemma2/__init__.py b/tests/models/shieldgemma2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/shieldgemma2/test_modeling_shieldgemma2.py b/tests/models/shieldgemma2/test_modeling_shieldgemma2.py new file mode 100644 index 0000000000..fdc5d9e713 --- /dev/null +++ b/tests/models/shieldgemma2/test_modeling_shieldgemma2.py @@ -0,0 +1,61 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Gemma3 model.""" + +import unittest +from io import BytesIO + +import requests +from PIL import Image + +from transformers import is_torch_available +from transformers.testing_utils import ( + cleanup, + require_torch_gpu, + slow, + torch_device, +) + + +if is_torch_available(): + import torch + + from transformers import ShieldGemma2ForImageClassification, ShieldGemma2Processor + + +@slow +@require_torch_gpu +# @require_read_token +class ShieldGemma2IntegrationTest(unittest.TestCase): + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + def test_model(self): + model_id = "google/shieldgemma-2-4b-it" + + processor = ShieldGemma2Processor.from_pretrained(model_id, padding_side="left") + url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png" + response = requests.get(url) + image = Image.open(BytesIO(response.content)) + + model = ShieldGemma2ForImageClassification.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 + ).to(torch_device) + + inputs = processor(images=[image]).to(torch_device) + output = model(**inputs) + self.assertEqual(len(output.probabilities), 3) + for element in output.probabilities: + self.assertEqual(len(element), 2) diff --git a/tests/models/shieldgemma2/test_processing_shieldgemma2.py b/tests/models/shieldgemma2/test_processing_shieldgemma2.py new file mode 100644 index 0000000000..31ae324870 --- /dev/null +++ b/tests/models/shieldgemma2/test_processing_shieldgemma2.py @@ -0,0 +1,220 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import tempfile +import unittest +from collections.abc import Mapping + +from parameterized import parameterized + +from transformers import GemmaTokenizer, ShieldGemma2Processor +from transformers.testing_utils import get_tests_dir, require_vision +from transformers.utils import is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from transformers import Gemma3ImageProcessor + +SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") + +# Copied from _CHAT_TEMPLATE in src/transformers/models/shieldgemma2/convert_shieldgemma2_weights_orbax_to_hf.py +_CHAT_TEMPLATE = """{{ bos_token }} +{%- for message in messages -%} + {%- if message['content'] is iterable -%} + {%- for item in message['content'] -%} + {{ 'user\n' }} + {%- if item['type'] == 'image' -%} + {{ '' }} + {%- elif item['type'] == 'text' -%} + {{ 'You are a policy expert trying to help determine whether a generated or real image is in' }} + {{ ' violation of the safety policy below. Your decision must very clearly follow the specific'}} + {{ ' language used in the safety policy and should not include any interpretation outside of the' }} + {{ ' safety policy.\n\n' }} + {{ '\n' }} + {{ item['text'] + '\n' }} + {{ '\n\n' }} + {{ 'Does the image violate the above policy? Your Answer must start with \\'Yes\\' or \\'No\\'.' }} + {{ '\n' }} + {%- endif -%} + {%- endfor -%} + {{'model\n'}} + {%- else -%} + {{ raise_exception("Conversation messages must contain iterable content containing images and policy definitions in text.") }} + {%- endif -%} +{%- endfor -%} +""" + +# Simplified from _SHIELDGEMMA2_POLICIES in src/transformers/models/shieldgemma2/convert_shieldgemma2_weights_orbax_to_hf.py +_SHIELDGEMMA2_POLICIES: Mapping[str, str] = { + "dangerous": "Test policy related to dangerous content.", + "sexual": "Test policy related to sexually explicit content.", + "violence": "Test policy related to violent content.", +} + + +@require_vision +class ShieldGemma2ProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = ShieldGemma2Processor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = Gemma3ImageProcessor.from_pretrained("google/siglip-so400m-patch14-384") + + extra_special_tokens = { + "image_token": "", + "boi_token": "", + "eoi_token": "", + } + tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True, extra_special_tokens=extra_special_tokens) + + processor_kwargs = self.prepare_processor_dict() + processor = ShieldGemma2Processor(image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs) + processor.save_pretrained(self.tmpdirname) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def prepare_processor_dict(self): + return { + "chat_template": _CHAT_TEMPLATE, + "policy_definitions": _SHIELDGEMMA2_POLICIES, + } + + def test_policy_definitions_saved_in_config(self): + processor_config_path = os.path.join(self.tmpdirname, "processor_config.json") + + with open(processor_config_path, "rb") as processor_config_file: + json_dict = json.load(processor_config_file) + + self.assertIsInstance(json_dict, dict) + self.assertIn("policy_definitions", json_dict) + self.assertIs(len(json_dict["policy_definitions"]), 3) + + @parameterized.expand( + [ + ("all_policies", None, 3), + ("selected_policies", ["dangerous", "violence"], 2), + ("single_policy", ["sexual"], 1), + ] + ) + def test_with_default_policies(self, name, policies, expected_batch_size): + processor = self.get_processor() + + if processor.chat_template is None: + self.skipTest("Processor has no chat template") + + images = self.prepare_image_inputs() + processed_inputs = processor(images=images, policies=policies) + self.assertEqual(len(processed_inputs[self.text_input_name]), expected_batch_size) + self.assertEqual(len(processed_inputs[self.images_input_name]), expected_batch_size) + + @parameterized.expand( + [ + ("all_policies", None, 6), + ("selected_policies_from_both", ["cbrne", "dangerous", "specialized_advice", "violence"], 4), + ("selected_policies_from_custom", ["cbrne", "specialized_advice"], 2), + ("selected_policies_from_default", ["dangerous", "violence"], 2), + ("single_policy_from_custom", ["ip"], 1), + ("single_policy_from_default", ["sexual"], 1), + ] + ) + def test_with_custom_policies(self, name, policies, expected_batch_size): + processor = self.get_processor() + + if processor.chat_template is None: + self.skipTest("Processor has no chat template") + + # Test policies adapated from https://ailuminate.mlcommons.org/benchmarks/ hazard categories + custom_policies = { + "cbrne": "Test policy related to indiscriminate weapons.", + "ip": "Test policy related to intellectual property.", + "specialized_advice": "Test policy related to specialized advice.", + } + + images = self.prepare_image_inputs() + processed_inputs = processor(images=images, custom_policies=custom_policies, policies=policies) + self.assertEqual(len(processed_inputs[self.text_input_name]), expected_batch_size) + self.assertEqual(len(processed_inputs[self.images_input_name]), expected_batch_size) + + def test_with_multiple_images(self): + processor = self.get_processor() + + if processor.chat_template is None: + self.skipTest("Processor has no chat template") + + images = self.prepare_image_inputs(batch_size=2) + print(images) + processed_inputs = processor(images=images) + self.assertEqual(len(processed_inputs[self.text_input_name]), 6) + self.assertEqual(len(processed_inputs[self.images_input_name]), 6) + + # TODO(ryanmullins): Adapt this test for ShieldGemma 2 + @unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.") + def test_chat_template_accepts_processing_kwargs(self): + pass + + # TODO(ryanmullins): Adapt this test for ShieldGemma 2 + @unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.") + def test_chat_template_batched(self): + pass + + # TODO(ryanmullins): Adapt this test for ShieldGemma 2 + @unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.") + def test_chat_template_dict_torch(self): + pass + + # TODO(ryanmullins): Adapt this test for ShieldGemma 2 + @unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.") + def test_chat_template_single(self): + pass + + # TODO(ryanmullins): Adapt this test for ShieldGemma 2 + @unittest.skip("Parent test needs to be adapted for ShieldGemma 2.") + def test_unstructured_kwargs_batched(self): + pass + + # TODO(ryanmullins): Adapt this test for ShieldGemma 2 + @unittest.skip("Parent test needs to be adapted for ShieldGemma 2.") + def test_unstructured_kwargs(self): + pass + + # TODO(ryanmullins): Adapt this test for ShieldGemma 2 + @unittest.skip("Parent test needs to be adapted for ShieldGemma 2.") + def test_tokenizer_defaults_preserved_by_kwargs(self): + pass + + # TODO(ryanmullins): Adapt this test for ShieldGemma 2 + @unittest.skip("Parent test needs to be adapted for ShieldGemma 2.") + def test_structured_kwargs_nested_from_dict(self): + pass + + # TODO(ryanmullins): Adapt this test for ShieldGemma 2 + @unittest.skip("Parent test needs to be adapted for ShieldGemma 2.") + def test_structured_kwargs_nested(self): + pass + + # TODO(ryanmullins): Adapt this test for ShieldGemma 2 + @unittest.skip("Parent test needs to be adapted for ShieldGemma 2.") + def test_kwargs_overrides_default_tokenizer_kwargs(self): + pass + + # TODO(ryanmullins): Adapt this test for ShieldGemma 2 + @unittest.skip("Parent test needs to be adapted for ShieldGemma 2.") + def test_kwargs_overrides_default_image_processor_kwargs(self): + pass diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 507046ea3c..7ce1bdd0f8 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -228,6 +228,14 @@ SPECIAL_CASES_TO_ALLOW = { "GPTNeoXConfig": ["rotary_emb_base"], "Gemma3Config": ["boi_token_index", "eoi_token_index"], "Gemma3TextConfig": ["cache_implementation", "tie_word_embeddings"], + "ShieldGemma2Config": [ + "boi_token_index", + "eoi_token_index", + "initializer_range", + "mm_tokens_per_image", + "text_config", + "vision_config", + ], } diff --git a/utils/check_repo.py b/utils/check_repo.py index 3b3dddf9cf..54bb9267c5 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -167,6 +167,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [ "models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py", "models/decision_transformer/test_modeling_decision_transformer.py", "models/bark/test_modeling_bark.py", + "models/shieldgemma2/test_modeling_shieldgemma2.py", ] # Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and