Shieldgemma2 (#36678)

* single commit

* correct config

* fixup

* dummy pt

* Use ShieldGemma2Config in conversion script

* Update src/transformers/models/shieldgemma2/configuration_shieldgemma2.py

* Adding shieldgemma2 to models.__init__.py

* Adding ShieldGemma2 to main __init__.py

* Update shieldgemma2.md

* Update shieldgemma2.md

* Adding tests. Addressing review feedback.

* Minor docs update

* Fixing code quality feedback from CI

* Fixing empty messages bug reported by ghunkins

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Ren Pang <ain-soph@live.com>
This commit is contained in:
Ryan Mullins
2025-03-20 10:14:38 -04:00
committed by GitHub
parent a63e92e2f0
commit 487dab1b2b
19 changed files with 1459 additions and 0 deletions

View File

@@ -0,0 +1,100 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# ShieldGemma 2
## Overview
The ShieldGemma 2 model was proposed in a forthcoming technical report by Google. ShieldGemma 2, built on [Gemma 3](https://ai.google.dev/gemma/docs/core/model_card_3), is a 4 billion (4B) parameter model that checks the safety of both synthetic and natural images against key categories to help you build robust datasets and models. With this addition to the Gemma family of models, researchers and developers can now easily minimize the risk of harmful content in their models across the key areas of harm defined below:
- No Sexually Explicit content: The image shall not contain content that depicts explicit or graphic sexual acts (e.g., pornography, erotic nudity, depictions of rape or sexual assault).
- No Dangerous Content: The image shall not contain content that facilitates or encourages activities that could cause real-world harm (e.g., building firearms and explosive devices, promotion of terrorism, instructions for suicide).
- No Violence/Gore content: The image shall not contain content that depicts shocking, sensational, or gratuitous violence (e.g., excessive blood and gore, gratuitous violence against animals, extreme injury or moment of death).
We recommend using ShieldGemma 2 as an input filter to vision language models, or as an output filter of image generation systems. To train a robust image safety model, we curated training datasets of natural and synthetic images and instruction-tuned Gemma 3 to demonstrate strong performance.
This model was contributed by [Ryan Mullins](https://huggingface.co/RyanMullins).
## Usage Example
- ShieldGemma 2 provides a Processor that accepts a list of `images` and an optional list of `policies` as input, and constructs a batch of prompts as the product of these two lists using the provided chat template.
- You can extend ShieldGemma's built-in policies with the `custom_policies` argument to the Processor. Using the same key as one of the built-in policies will overwrite that policy with your custom definition.
- ShieldGemma 2 does not support the image cropping capabilities used by Gemma 3.
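The batching and override behaviour described in the list above can be sketched in plain Python. This is a minimal illustration, not the Processor's implementation: `BUILT_IN_POLICIES`, `resolve_policies`, and `build_pairs` are hypothetical names, the policy texts are abbreviated, and the image-major ordering of the product is an assumption.

```python
from itertools import product

# Hypothetical stand-ins for the built-in policy definitions (texts abbreviated).
BUILT_IN_POLICIES = {
    "dangerous": "No Dangerous Content: ...",
    "sexual": "No Sexually Explicit content: ...",
    "violence": "No Violence/Gore content: ...",
}


def resolve_policies(custom_policies=None):
    """Merge custom policies over the built-ins; a reused key replaces the built-in text."""
    return {**BUILT_IN_POLICIES, **(custom_policies or {})}


def build_pairs(images, policies=None, custom_policies=None):
    """Return one (image, policy_key, policy_text) triple per image/policy combination."""
    definitions = resolve_policies(custom_policies)
    keys = list(policies) if policies is not None else list(definitions)
    # Assumed ordering: all policies for the first image, then the second image, and so on.
    return [(image, key, definitions[key]) for image, key in product(images, keys)]


pairs = build_pairs(
    images=["image_0", "image_1"],
    policies=["dangerous", "key_a"],
    custom_policies={"key_a": "description_a", "dangerous": "My stricter definition."},
)
for image, key, text in pairs:
    print(image, key, text)
```

Two images and two policies yield four prompts, and the custom `dangerous` text replaces the built-in one because the later dictionary wins in the merge.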
### Classification against Built-in Policies
```python
from PIL import Image
import requests
from transformers import AutoProcessor, ShieldGemma2ForImageClassification
model_id = "google/shieldgemma-2-4b-it"
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=[image], return_tensors="pt").to(model.device)
output = model(**inputs)
print(output.probabilities)
```
### Classification against Custom Policies
```python
from PIL import Image
import requests
from transformers import AutoProcessor, ShieldGemma2ForImageClassification
model_id = "google/shieldgemma-2-4b-it"
model = ShieldGemma2ForImageClassification.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(url, stream=True).raw)
custom_policies = {
"key_a": "description_a",
"key_b": "description_b",
}
inputs = processor(
images=[image],
custom_policies=custom_policies,
policies=["dangerous", "key_a", "key_b"],
return_tensors="pt",
).to(model.device)
output = model(**inputs)
print(output.probabilities)
```
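To turn `output.probabilities` into a decision you need to know how its columns are laid out. The sketch below assumes each row is a two-way softmax over the logits of the "Yes" and "No" answer tokens, in that order. That reading matches the Yes/No framing of the prompt but should be checked against the model's API documentation. `yes_no_probabilities`, `violates`, and the 0.5 threshold are illustrative choices, not part of the library.

```python
import math


def yes_no_probabilities(yes_logit: float, no_logit: float) -> tuple[float, float]:
    """Two-way softmax over the 'Yes' and 'No' logits (assumed column order: yes, no)."""
    # Subtract the max for numerical stability before exponentiating.
    m = max(yes_logit, no_logit)
    e_yes, e_no = math.exp(yes_logit - m), math.exp(no_logit - m)
    total = e_yes + e_no
    return e_yes / total, e_no / total


def violates(yes_logit: float, no_logit: float, threshold: float = 0.5) -> bool:
    """Flag the image for a policy when P('Yes') exceeds the chosen threshold."""
    p_yes, _ = yes_no_probabilities(yes_logit, no_logit)
    return p_yes > threshold


p_yes, p_no = yes_no_probabilities(2.0, 0.0)
print(round(p_yes, 3), round(p_no, 3))
print(violates(2.0, 0.0), violates(-1.0, 3.0))
```

In practice the threshold is a tuning knob: a lower value catches more violations at the cost of more false positives.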
## ShieldGemma2Processor
[[autodoc]] ShieldGemma2Processor
## ShieldGemma2Config
[[autodoc]] ShieldGemma2Config
## ShieldGemma2ForImageClassification
[[autodoc]] ShieldGemma2ForImageClassification
- forward

View File

@@ -774,6 +774,10 @@ _import_structure = {
"models.seggpt": ["SegGptConfig"],
"models.sew": ["SEWConfig"],
"models.sew_d": ["SEWDConfig"],
"models.shieldgemma2": [
"ShieldGemma2Config",
"ShieldGemma2Processor",
],
"models.siglip": [
"SiglipConfig",
"SiglipProcessor",
@@ -3581,6 +3585,7 @@ else:
"SEWDPreTrainedModel",
]
)
_import_structure["models.shieldgemma2"].append("ShieldGemma2ForImageClassification")
_import_structure["models.siglip"].extend(
[
"SiglipForImageClassification",
@@ -5982,6 +5987,10 @@ if TYPE_CHECKING:
from .models.seggpt import SegGptConfig
from .models.sew import SEWConfig
from .models.sew_d import SEWDConfig
from .models.shieldgemma2 import (
ShieldGemma2Config,
ShieldGemma2Processor,
)
from .models.siglip import (
SiglipConfig,
SiglipProcessor,
@@ -8350,6 +8359,9 @@ if TYPE_CHECKING:
SEWDModel,
SEWDPreTrainedModel,
)
from .models.shieldgemma2 import (
ShieldGemma2ForImageClassification,
)
from .models.siglip import (
SiglipForImageClassification,
SiglipModel,

View File

@@ -247,6 +247,7 @@ from . import (
seggpt,
sew,
sew_d,
shieldgemma2,
siglip,
siglip2,
smolvlm,

View File

@@ -274,6 +274,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("seggpt", "SegGptConfig"),
("sew", "SEWConfig"),
("sew-d", "SEWDConfig"),
("shieldgemma2", "ShieldGemma2Config"),
("siglip", "SiglipConfig"),
("siglip2", "Siglip2Config"),
("siglip_vision_model", "SiglipVisionConfig"),
@@ -625,6 +626,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("seggpt", "SegGPT"),
("sew", "SEW"),
("sew-d", "SEW-D"),
("shieldgemma2", "Shieldgemma2"),
("siglip", "SigLIP"),
("siglip2", "SigLIP2"),
("siglip2_vision_model", "Siglip2VisionModel"),

View File

@@ -137,6 +137,7 @@ else:
("sam", ("SamImageProcessor",)),
("segformer", ("SegformerImageProcessor",)),
("seggpt", ("SegGptImageProcessor",)),
("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")),
("superglue", "SuperGlueImageProcessor"),

View File

@@ -727,6 +727,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("regnet", "RegNetForImageClassification"),
("resnet", "ResNetForImageClassification"),
("segformer", "SegformerForImageClassification"),
("shieldgemma2", "ShieldGemma2ForImageClassification"),
("siglip", "SiglipForImageClassification"),
("siglip2", "Siglip2ForImageClassification"),
("swiftformer", "SwiftFormerForImageClassification"),
@@ -849,6 +850,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
("pixtral", "LlavaForConditionalGeneration"),
("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
("shieldgemma2", "Gemma3ForConditionalGeneration"),
("smolvlm", "SmolVLMForConditionalGeneration"),
("udop", "UdopForConditionalGeneration"),
("vipllava", "VipLlavaForConditionalGeneration"),

View File

@@ -101,6 +101,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("seamless_m4t", "SeamlessM4TProcessor"),
("sew", "Wav2Vec2Processor"),
("sew-d", "Wav2Vec2Processor"),
("shieldgemma2", "ShieldGemma2Processor"),
("siglip", "SiglipProcessor"),
("siglip2", "Siglip2Processor"),
("speech_to_text", "Speech2TextProcessor"),

View File

@@ -493,6 +493,13 @@ else:
"SeamlessM4TTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"shieldgemma2",
(
"GemmaTokenizer" if is_sentencepiece_available() else None,
"GemmaTokenizerFast" if is_tokenizers_available() else None,
),
),
("siglip", ("SiglipTokenizer" if is_sentencepiece_available() else None, None)),
(
"siglip2",

View File

@@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_shieldgemma2 import *
from .modeling_shieldgemma2 import *
from .processing_shieldgemma2 import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@@ -0,0 +1,115 @@
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING, AutoConfig
logger = logging.get_logger(__name__)
class ShieldGemma2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ShieldGemma2ForImageClassification`]. It is used to instantiate a
ShieldGemma2ForImageClassification according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of shieldgemma-2-4b-it,
e.g. [google/shieldgemma-2-4b-it](https://huggingface.co/google/shieldgemma-2-4b-it)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
The config object of the text backbone.
vision_config (`Union[AutoConfig, dict]`, *optional*):
Custom vision config or dict.
mm_tokens_per_image (`int`, *optional*, defaults to 256):
The number of tokens per image embedding.
boi_token_index (`int`, *optional*, defaults to 255999):
The begin-of-image token index to wrap the image prompt.
eoi_token_index (`int`, *optional*, defaults to 256000):
The end-of-image token index to wrap the image prompt.
image_token_index (`int`, *optional*, defaults to 262144):
The image token index to encode the image prompt.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:
```python
>>> from transformers import ShieldGemma2ForImageClassification, ShieldGemma2Config, SiglipVisionConfig, Gemma3TextConfig
>>> # Initializing a Siglip-like vision config
>>> vision_config = SiglipVisionConfig()
>>> # Initializing a Gemma 3 text config
>>> text_config = Gemma3TextConfig()
>>> # Initializing a ShieldGemma2 shieldgemma-2-4b-it style configuration
>>> configuration = ShieldGemma2Config(text_config=text_config, vision_config=vision_config)
>>> # Initializing a model from the shieldgemma-2-4b-it style configuration
>>> model = ShieldGemma2ForImageClassification(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "shieldgemma2"
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
def __init__(
self,
text_config=None,
vision_config=None,
mm_tokens_per_image: int = 256,
boi_token_index: int = 255_999,
eoi_token_index: int = 256_000,
image_token_index: int = 262_144,
initializer_range: float = 0.02,
**kwargs,
):
if isinstance(vision_config, dict):
vision_config["model_type"] = (
vision_config["model_type"] if "model_type" in vision_config else "siglip_vision_model"
)
vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
elif vision_config is None:
vision_config = CONFIG_MAPPING["siglip_vision_model"]()
self.vision_config = vision_config
if isinstance(text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma3_text"
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
text_config = CONFIG_MAPPING["gemma3_text"]()
self.text_config = text_config
self.mm_tokens_per_image = mm_tokens_per_image
self.boi_token_index = boi_token_index
self.eoi_token_index = eoi_token_index
self.image_token_index = image_token_index
self.initializer_range = initializer_range
super().__init__(**kwargs)
__all__ = ["ShieldGemma2Config"]
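The sub-config handling in `__init__` follows a common pattern: accept `None`, a plain dict, or a ready-made config object, and normalize all three to a config object. A minimal stand-alone sketch of that pattern follows. `VisionConfig`, `REGISTRY`, and `resolve_sub_config` are hypothetical names standing in for the real config classes and `CONFIG_MAPPING`. Unlike the code above, this sketch copies the dict instead of writing `model_type` back into the caller's dict.

```python
from dataclasses import dataclass


@dataclass
class VisionConfig:  # hypothetical stand-in for a real config class
    hidden_size: int = 1152
    model_type: str = "siglip_vision_model"


# Hypothetical registry playing the role of CONFIG_MAPPING.
REGISTRY = {"siglip_vision_model": VisionConfig}


def resolve_sub_config(value, default_type: str):
    """Accept None, a dict, or a ready-made config object and return a config object."""
    if value is None:
        return REGISTRY[default_type]()
    if isinstance(value, dict):
        # Copy so the caller's dict is not mutated, then fill in the default type.
        kwargs = dict(value)
        kwargs.setdefault("model_type", default_type)
        return REGISTRY[kwargs["model_type"]](**kwargs)
    return value


user_dict = {"hidden_size": 64}
print(resolve_sub_config(None, "siglip_vision_model"))
print(resolve_sub_config(user_dict, "siglip_vision_model"))
```

Copying the dict is a design choice of the sketch: it avoids surprising callers who reuse the same dict elsewhere.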

View File

@@ -0,0 +1,470 @@
r"""Utility to convert ShieldGemma 2 models from Orbax to a HF Transformers checkpoint.
python -m transformers.models.shieldgemma2.convert_shieldgemma2_weights_orbax_to_hf \
--tokenizer_path="$HOME/gemma3/tokenizer/gemma3_cleaned_262144_v2.spiece.model" \
--checkpoint_path_gemma="$HOME/gemma3/gemma3_4b_pt_orbax/" \
--checkpoint_path_shieldgemma="$HOME/shieldgemma2/shieldgemma-2_4b_orbax/" \
--output_path="$HOME/shieldgemma2/shieldgemma2_4b_pt_safetensors/" \
--precision='bfloat16'
"""
import dataclasses
from collections.abc import Iterator, Mapping, Sequence
from typing import Any
import accelerate
import numpy as np
import torch
import tree
from absl import app, flags, logging
from orbax import checkpoint as obc
from ...image_utils import PILImageResampling
from ..gemma import GemmaTokenizerFast
from ..gemma3 import Gemma3ImageProcessor, Gemma3TextConfig
from ..siglip import SiglipVisionConfig
from . import ShieldGemma2Config, ShieldGemma2ForImageClassification, ShieldGemma2Processor
# ==== Internal Constants and Classes ====
_CHAT_TEMPLATE = """{{ bos_token }}
{%- for message in messages -%}
{%- if message['content'] is iterable -%}
{%- for item in message['content'] -%}
{{ '<start_of_turn>user\n' }}
{%- if item['type'] == 'image' -%}
{{ '<start_of_image>' }}
{%- elif item['type'] == 'text' -%}
{{ 'You are a policy expert trying to help determine whether a generated or real image is in' }}
{{ ' violation of the safety policy below. Your decision must very clearly follow the specific'}}
{{ ' language used in the safety policy and should not include any interpretation outside of the' }}
{{ ' safety policy.\n\n' }}
{{ '<BEGIN_SAFETY_POLICY>\n' }}
{{ item['text'] + '\n' }}
{{ '</END_SAFETY_POLICY>\n\n' }}
{{ 'Does the image violate the above policy? Your Answer must start with \\'Yes\\' or \\'No\\'.' }}
{{ '<end_of_turn>\n' }}
{%- endif -%}
{%- endfor -%}
{{'<start_of_turn>model\n'}}
{%- else -%}
{{ raise_exception("Conversation messages must contain iterable content containing images and policy definitions in text.") }}
{%- endif -%}
{%- endfor -%}
"""
_DTYPES = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
"float16": torch.float16,
}
_SHIELDGEMMA2_POLICIES: Mapping[str, str] = {
"dangerous": (
"No Dangerous Content: The image shall not contain content that"
" facilitates or encourages activities that could cause real-world harm"
" (e.g., building firearms and explosive devices, promotion of"
" terrorism, instructions for suicide)."
),
"sexual": (
"No Sexually Explicit content: The image shall not contain content that"
" depicts explicit or graphic sexual acts (e.g., pornography, erotic"
" nudity, depictions of rape or sexual assault)."
),
"violence": (
"No Violence/Gore content: The image shall not contain content that"
" depicts shocking, sensational, or gratuitous violence (e.g.,"
" excessive blood and gore, gratuitous violence against animals,"
" extreme injury or moment of death)."
),
}
_SIGLIP_BASE = "SigLiPFromPatches_0/siglip_encoder"
_SIGLIP_EMBEDDING = "SigLiPFromPatches_0/siglip_encoder/embedding"
_SIGLIP_TRANSFORMER_ENCODER_BLOCK = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_"
_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN = len(_SIGLIP_TRANSFORMER_ENCODER_BLOCK)
_SIGLIP_TRANSFORMER_ENCODER_NORM = "SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm"
_TRANSFORMER_DECODER_BLOCK = "transformer/layer_"
_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK)
_TRANSFORMER_EMBEDDER = "transformer/embedder"
_TRANSFORMER_FINAL_NORM = "transformer/final_norm"
_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/"
_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX)
# ==== Flags ====
_GEMMA_CHECKPOINT_PATH = flags.DEFINE_string(
name="checkpoint_path_gemma",
default=None,
help="Path to the Orbax checkpoint containing the vision weights.",
required=True,
)
_SHIELDGEMMA_CHECKPOINT_PATH = flags.DEFINE_string(
name="checkpoint_path_shieldgemma",
default=None,
help="Path to the Orbax checkpoint containing the language model weights.",
required=True,
)
OUTPUT_PATH = flags.DEFINE_string(
name="output_path",
default=None,
help="Path to store the HF checkpoint.",
required=True,
)
PRECISION = flags.DEFINE_enum(
name="precision",
default=None,
help="The floating point precision (aka dtype) of the model.",
enum_values=set(_DTYPES.keys()),
required=True,
)
TOKENIZER_PATH = flags.DEFINE_string(
name="tokenizer_path",
default=None,
help="Path to the SentencePiece model file.",
required=True,
)
def convert_siglip_weight(
config: SiglipVisionConfig,
paths: Sequence[str],
weights: np.ndarray,
) -> tuple[str, np.ndarray]:
path, prop = paths
normalized_path: str = ""
updated_weights: np.ndarray = None
if path == _SIGLIP_BASE:
normalized_path = "vision_tower.vision_model.embeddings.position_embedding.weight"
updated_weights = weights.reshape(-1, config.hidden_size)
elif path == _SIGLIP_EMBEDDING:
if prop == "kernel":
normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.weight"
updated_weights = weights.transpose(3, 2, 0, 1)
elif prop == "bias":
normalized_path = "vision_tower.vision_model.embeddings.patch_embedding.bias"
updated_weights = weights
else:
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
elif path.startswith(_SIGLIP_TRANSFORMER_ENCODER_BLOCK):
encoder_block_path = path[_SIGLIP_TRANSFORMER_ENCODER_BLOCK_LEN:]
next_path_separator_idx = encoder_block_path.find("/")
layer_idx = encoder_block_path[:next_path_separator_idx]
encoder_block_path = encoder_block_path[next_path_separator_idx:]
normalized_path = f"vision_tower.vision_model.encoder.layers.{layer_idx}"
if encoder_block_path.startswith("/LayerNorm"):
normalized_path += ".layer_norm1" if path.endswith("_0") else ".layer_norm2"
if prop == "scale":
normalized_path += ".weight"
updated_weights = weights.transpose()
elif prop == "bias":
normalized_path += ".bias"
updated_weights = weights
else:
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.")
elif encoder_block_path.startswith("/MlpBlock_0"):
normalized_path += ".mlp.fc1" if "/Dense_0" in encoder_block_path else ".mlp.fc2"
if prop == "kernel":
normalized_path += ".weight"
updated_weights = weights.transpose()
elif prop == "bias":
normalized_path += ".bias"
updated_weights = weights
else:
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
elif encoder_block_path.startswith("/MultiHeadDotProductAttention_0"):
if encoder_block_path.endswith("/key"):
normalized_path += ".self_attn.k_proj"
elif encoder_block_path.endswith("/out"):
normalized_path += ".self_attn.out_proj"
elif encoder_block_path.endswith("/query"):
normalized_path += ".self_attn.q_proj"
elif encoder_block_path.endswith("/value"):
normalized_path += ".self_attn.v_proj"
else:
raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer MultiHeadDotProductAttention_0.")
if prop == "bias":
normalized_path += ".bias"
updated_weights = weights.reshape(-1, config.hidden_size).reshape(-1)
elif prop == "kernel":
normalized_path += ".weight"
updated_weights = weights.reshape(-1, config.hidden_size).transpose()
else:
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `kernel`.")
else:
raise ValueError(f"Unexpected path `{path}` in SigLIP Transformer Encoder Block.")
elif path == _SIGLIP_TRANSFORMER_ENCODER_NORM:
if prop == "scale":
normalized_path = "vision_tower.vision_model.post_layernorm.weight"
updated_weights = weights.transpose()
elif prop == "bias":
normalized_path = "vision_tower.vision_model.post_layernorm.bias"
updated_weights = weights
else:
raise ValueError(f"Unexpected member, `{prop}`, for path `{path}`. Should be `bias` or `scale`.")
else:
raise ValueError(f"Unexpected path `{path}`.")
if "vision" in normalized_path:
print(normalized_path)
return normalized_path, updated_weights
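The `transpose(3, 2, 0, 1)` applied to the patch-embedding kernel is easiest to read as a layout change. The sketch below assumes the checkpoint stores the kernel in Flax's convolution layout `(height, width, in_channels, out_channels)` and that the target is PyTorch's `Conv2d` layout `(out_channels, in_channels, height, width)`. The dimensions are toy values, not the real model's.

```python
import numpy as np

# Toy dimensions: 14x14 patches, 3 input channels, 8 output channels.
height, width, in_channels, out_channels = 14, 14, 3, 8
flax_kernel = np.arange(height * width * in_channels * out_channels, dtype=np.float32).reshape(
    height, width, in_channels, out_channels
)

# (H, W, I, O) -> (O, I, H, W)
torch_style_kernel = flax_kernel.transpose(3, 2, 0, 1)

print(torch_style_kernel.shape)
# The same weight is addressed with permuted indices in the two layouts.
print(torch_style_kernel[5, 2, 1, 7] == flax_kernel[1, 7, 2, 5])
```

No values change during the conversion. Only the order of the axes does, which is why a single `transpose` call is enough.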
def convert_transformer_weights(
config: Gemma3TextConfig,
paths: Sequence[str],
weights: np.ndarray,
) -> Iterator[tuple[str, np.ndarray]]:
path, prop = paths
if path.startswith(_TRANSFORMER_POST_TRAINING_PREFIX):
path = path[_TRANSFORMER_POST_TRAINING_PREFIX_LEN:]
converted_paths: list[str] = []
converted_weights: list[Any] = []
attn_head_dim = config.num_attention_heads * config.head_dim
kv_head_dim = config.num_key_value_heads * config.head_dim
if path == _TRANSFORMER_EMBEDDER:
if prop == "input_embedding":
# Tied to language_model.lm_head.weight, assigned at the end.
converted_paths = ["language_model.model.embed_tokens.weight"]
# Gemma3 model doesn't have image soft token in input and output embeddings, resize to avoid bugs we had with Mllama
pre_expansion_embeddings = weights
mu = np.mean(pre_expansion_embeddings, axis=0)
sigma = np.cov(pre_expansion_embeddings, rowvar=False, bias=True)
new_embeddings = np.random.multivariate_normal(mu, 1e-5 * sigma, size=64)
weights = np.vstack([pre_expansion_embeddings, new_embeddings])
converted_weights = [weights]
else:
raise ValueError(f"Unexpected member, {prop}, in Embedder.")
elif path.startswith(f"{_TRANSFORMER_EMBEDDER}/mm"):
if path.endswith("/mm_input_projection"):
converted_paths = ["multi_modal_projector.mm_input_projection_weight"]
converted_weights = [weights]
elif path.endswith("/mm_soft_embedding_norm"):
converted_paths = ["multi_modal_projector.mm_soft_emb_norm.weight"]
converted_weights = [weights]
else:
raise ValueError(f"Unexpected subpath, `{path}`, in Embedder.")
elif path == _TRANSFORMER_FINAL_NORM:
converted_paths = ["language_model.model.norm.weight"]
converted_weights = [weights]
elif path.startswith(_TRANSFORMER_DECODER_BLOCK):
decoder_block_path = path[_TRANSFORMER_DECODER_BLOCK_LEN:]
next_path_separator_idx = decoder_block_path.find("/")
layer_idx = decoder_block_path[:next_path_separator_idx]
decoder_block_path = decoder_block_path[next_path_separator_idx:]
base_path = f"language_model.model.layers.{layer_idx}"
if path.endswith("attn/attn_vec_einsum"):
converted_paths = [f"{base_path}.self_attn.o_proj.weight"]
converted_weights = [weights.transpose(2, 0, 1).reshape(config.hidden_size, attn_head_dim)]
elif path.endswith("attn/_key_norm"):
converted_paths = [f"{base_path}.self_attn.k_norm.weight"]
converted_weights = [weights]
elif path.endswith("attn/kv_einsum"):
converted_paths = [
f"{base_path}.self_attn.k_proj.weight",
f"{base_path}.self_attn.v_proj.weight",
]
k_proj_weights, v_proj_weights = weights
converted_weights = [
k_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size),
v_proj_weights.transpose(0, 2, 1).reshape(kv_head_dim, config.hidden_size),
]
elif path.endswith("attn/q_einsum"):
converted_paths = [f"{base_path}.self_attn.q_proj.weight"]
converted_weights = [weights.transpose(0, 2, 1).reshape(attn_head_dim, config.hidden_size)]
elif path.endswith("attn/_query_norm"):
converted_paths = [f"{base_path}.self_attn.q_norm.weight"]
converted_weights = [weights]
elif path.endswith("mlp/gating_einsum"):
converted_paths = [
f"{base_path}.mlp.gate_proj.weight",
f"{base_path}.mlp.up_proj.weight",
]
gate_proj_weight, up_proj_weight = weights
converted_weights = [gate_proj_weight, up_proj_weight]
elif path.endswith("mlp/linear"):
converted_paths = [f"{base_path}.mlp.down_proj.weight"]
converted_weights = [weights.transpose()]
elif path.endswith("post_attention_norm"):
converted_paths = [f"{base_path}.post_attention_layernorm.weight"]
converted_weights = [weights]
elif path.endswith("post_ffw_norm"):
converted_paths = [f"{base_path}.post_feedforward_layernorm.weight"]
converted_weights = [weights]
elif path.endswith("pre_attention_norm"):
converted_paths = [f"{base_path}.input_layernorm.weight"]
converted_weights = [weights]
elif path.endswith("pre_ffw_norm"):
converted_paths = [f"{base_path}.pre_feedforward_layernorm.weight"]
converted_weights = [weights]
else:
raise ValueError(f"Unexpected path `{path}` in Decoder Block.")
else:
raise ValueError(f"Unexpected path `{path}`.")
if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
raise ValueError(
"The `converted_paths` and `converted_weights` should be the same "
f"length. Got {cpl} and {cwl}, respectively, for {path}."
)
return zip(converted_paths, converted_weights)
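The embedding expansion in the `input_embedding` branch fits a Gaussian to the existing rows and samples the extra rows from a tightly scaled copy of it, so the new rows start close to the mean embedding. Here is a toy-sized sketch of the same steps. It uses a seeded `numpy.random.Generator` for repeatability, whereas the script calls the legacy `np.random.multivariate_normal`, and the sizes are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded only to make the sketch repeatable

vocab_size, hidden_size, num_new_rows = 10, 4, 3  # toy sizes
embeddings = rng.normal(size=(vocab_size, hidden_size))

# Fit a Gaussian to the existing rows ...
mu = embeddings.mean(axis=0)
sigma = np.cov(embeddings, rowvar=False, bias=True)

# ... and sample the new rows from a tightly scaled version of it.
new_rows = rng.multivariate_normal(mu, 1e-5 * sigma, size=num_new_rows)
expanded = np.vstack([embeddings, new_rows])

print(expanded.shape)
print(np.abs(new_rows - mu).max())
```

If the checkpoint's embedding table has 262,144 rows (an assumption suggested by the tokenizer file name in the module docstring), the 64 sampled rows bring it to the `vocab_size=262_208` set in `main`.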
def transpose_reshape(x: torch.Tensor) -> torch.Tensor:
x = x.transpose(1, 2)
return x.reshape(x.shape[0] * x.shape[1], x.shape[2]).contiguous()
@dataclasses.dataclass(frozen=True)
class ConversionResult:
state_tree: dict[str, torch.Tensor]
config: ShieldGemma2Config
def convert(
shieldgemma_checkpoint_path: str,
gemma_checkpoint_path: str,
config: ShieldGemma2Config,
target_dtype: torch.dtype,
) -> ConversionResult:
"""Loads the ShieldGemma 2 and Gemma 3 Orbax checkpoints and converts them to an HF state tree."""
checkpointer = obc.PyTreeCheckpointer()
sg2_ckpt = checkpointer.restore(shieldgemma_checkpoint_path)
g3_ckpt = checkpointer.restore(gemma_checkpoint_path)
hf_tree: dict[str, torch.Tensor] = {}
def update_tree(path: str, weights: np.ndarray) -> None:
torch_tensor = torch.from_numpy(weights.astype("float32")).type(target_dtype)
logging.info(
"%s converted shape=%s with dtype=%s",
path,
weights.shape,
torch_tensor.dtype,
)
hf_tree[f"model.{path}"] = torch_tensor
for paths, value in tree.flatten_with_path(g3_ckpt):
if paths[0].startswith("SigLiPFromPatches_"):
path, weights = convert_siglip_weight(config=config.vision_config, paths=paths, weights=value)
update_tree(path, weights)
for paths, value in tree.flatten_with_path(sg2_ckpt):
for path, weights in convert_transformer_weights(config=config.text_config, paths=paths, weights=value):
update_tree(path, weights)
hf_tree["model.language_model.lm_head.weight"] = hf_tree["model.language_model.model.embed_tokens.weight"]
return ConversionResult(state_tree=hf_tree, config=config)
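Both loops in `convert` rely on `tree.flatten_with_path` yielding `(paths, value)` pairs that the converter functions unpack as `path, prop = paths`. The pure-Python stand-in below illustrates that shape of traversal. It is a simplified sketch, not the `tree` library's implementation, and the toy checkpoint keys are hypothetical examples modeled on the path constants defined earlier in the script.

```python
def flatten_with_path(structure, prefix=()):
    """Yield (path_tuple, leaf) pairs from nested dicts, visiting keys in sorted order."""
    if isinstance(structure, dict):
        for key in sorted(structure):
            yield from flatten_with_path(structure[key], prefix + (key,))
    else:
        yield prefix, structure


# Toy checkpoint with the two-level shape the converters expect: path -> property -> array.
toy_checkpoint = {
    "transformer/final_norm": {"scale": [1.0, 1.0]},
    "SigLiPFromPatches_0/siglip_encoder/embedding": {"kernel": [[0.1]], "bias": [0.0]},
}

routed = []
for paths, value in flatten_with_path(toy_checkpoint):
    path, prop = paths  # the same unpacking both converter functions perform
    owner = "vision" if paths[0].startswith("SigLiPFromPatches_") else "text"
    routed.append((owner, path, prop))
    print(owner, path, prop)
```

The prefix test mirrors how the first loop in `convert` keeps only the SigLIP entries from the Gemma 3 checkpoint.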
def main(*args):
del args
dtype = getattr(torch, PRECISION.value)
output_path = OUTPUT_PATH.value
tokenizer = GemmaTokenizerFast(
TOKENIZER_PATH.value,
extra_special_tokens={
"image_token": "<image_soft_token>", # Should be ID=262_144
"boi_token": "<start_of_image>", # Should be ID=255_999
"eoi_token": "<end_of_image>", # Should be ID=256_000
},
)
yes_token_index, no_token_index = torch.tensor(tokenizer(["Yes", "No"])["input_ids"])[:, 1].numpy()
config = ShieldGemma2Config(
yes_token_index=int(yes_token_index),
no_token_index=int(no_token_index),
text_config=Gemma3TextConfig(
vocab_size=262_208,
hidden_size=2560,
intermediate_size=2560 * 8 // 2,
num_attention_heads=8,
head_dim=256,
num_hidden_layers=34,
num_key_value_heads=4,
sliding_window=1024,
rope_scaling={"rope_type": "linear", "factor": 8.0}, # used for global RoPE only
rope_theta=1_000_000,
rope_local_base_freq=10_000,
attn_logit_softcapping=None,
query_pre_attn_scalar=256,
max_position_embeddings=8192,
),
vision_config={
"hidden_size": 1152,
"intermediate_size": 4304,
"num_hidden_layers": 27,
"num_attention_heads": 16,
"num_channels": 3,
"image_size": 896,
"patch_size": 14,
"hidden_act": "gelu_pytorch_tanh",
"layer_norm_eps": 1e-6,
"attention_dropout": 0.0,
"vision_use_head": False,
},
)
config.save_pretrained(output_path)
image_processor = Gemma3ImageProcessor(
image_seq_length=256,
image_mean=(0.5,) * 3,
image_std=(0.5,) * 3,
size={"height": 896, "width": 896},
resample=PILImageResampling.BILINEAR,
)
processor = ShieldGemma2Processor(
image_processor=image_processor,
tokenizer=tokenizer,
policy_definitions=_SHIELDGEMMA2_POLICIES,
)
tokenizer.chat_template = _CHAT_TEMPLATE
processor.chat_template = _CHAT_TEMPLATE
processor.save_pretrained(output_path)
logging.info("Saved Shieldgemma2Processor to %s", output_path)
del processor
del tokenizer
logging.info("Converting Shieldgemma2 @ %s", dtype)
result = convert(_SHIELDGEMMA_CHECKPOINT_PATH.value, _GEMMA_CHECKPOINT_PATH.value, config, dtype)
logging.info("Converted Shieldgemma2 state tree from Orbax to Hugging Face.")
with accelerate.init_empty_weights():
model = ShieldGemma2ForImageClassification(config=config)
model.load_state_dict(result.state_tree, assign=True, strict=True)
model.config.torch_dtype = dtype
logging.info("Loaded Shieldgemma2 in Hugging Face Transformers.")
model.save_pretrained(output_path, safe_serialization=True)
logging.info("Saved Shieldgemma2 to SafeTensors in %s", output_path)
del model
del result
if __name__ == "__main__":
app.run(main)
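
The `yes_token_index` / `no_token_index` lookup in `main` keeps column 1 of each encoded row, which presumably skips a BOS token that the tokenizer prepends. Below is a minimal, dependency-free sketch of that indexing. The `fake_encode` helper and the BOS ID of 2 are assumptions made for illustration only; 10_784 and 3771 are the fallback IDs that appear in the modeling code.

```python
# Hypothetical stand-in for tokenizer(["Yes", "No"])["input_ids"].
# Assumption: the tokenizer prepends exactly one BOS token (ID 2 here).
ASSUMED_BOS_ID = 2
ASSUMED_VOCAB = {"Yes": 10_784, "No": 3771}  # fallback IDs from the modeling code


def fake_encode(texts):
    """Return one [BOS, token] row per single-token input text."""
    return [[ASSUMED_BOS_ID, ASSUMED_VOCAB[text]] for text in texts]


input_ids = fake_encode(["Yes", "No"])
# Plain-Python equivalent of `tensor[:, 1]`: take the second column, skipping BOS.
yes_token_index, no_token_index = (row[1] for row in input_ids)
print(yes_token_index, no_token_index)
```

If a tokenizer did not prepend BOS, or split "Yes" into several pieces, column 1 would no longer be the right slot, so checking the encoded length before indexing seems a reasonable safeguard.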

@@ -0,0 +1,228 @@
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
import torch.utils.checkpoint
from ...cache_utils import Cache
from ...modeling_outputs import ImageClassifierOutputWithNoAttention
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from ..auto import AutoModelForImageTextToText
from .configuration_shieldgemma2 import ShieldGemma2Config
_CHECKPOINT_FOR_DOC = "google/shieldgemma-2-4b-it"
_CONFIG_FOR_DOC = "ShieldGemma2Config"
logger = logging.get_logger(__name__)
SHIELDGEMMA2_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
Returns:
A `ShieldGemma2ImageClassifierOutputWithNoAttention` instance containing the logits and probabilities
associated with the model predicting the `Yes` or `No` token as the response to that prompt, captured in the
following properties.
* `logits` (`torch.Tensor` of shape `(batch_size, 2)`):
The first position along dim=1 is the logits for the `Yes` token and the second position along dim=1 is
the logits for the `No` token.
* `probabilities` (`torch.Tensor` of shape `(batch_size, 2)`):
The first position along dim=1 is the probability of predicting the `Yes` token and the second position
along dim=1 is the probability of predicting the `No` token.
ShieldGemma prompts are constructed such that predicting the `Yes` token means the content *does violate* the
policy as described. If you are only interested in the violative condition, use
`violated = outputs.probabilities[:, 0]` to extract that slice from the output tensors.
When used with the `ShieldGemma2Processor`, the `batch_size` will be equal to `len(images) * len(policies)`,
and the order within the batch will be img1_policy1, ... img1_policyN, ... imgM_policyN.
"""
@dataclass
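class ShieldGemma2ImageClassifierOutputWithNoAttention(ImageClassifierOutputWithNoAttention):
"""ShieldGemma2 classifies images as violative or not relative to a specific policy.
Args:
probabilities (`torch.Tensor` of shape `(batch_size, 2)`):
Probabilities of predicting the `Yes` token (first position) and the `No` token (second position).
"""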
probabilities: torch.Tensor = None
class ShieldGemma2ForImageClassification(PreTrainedModel):
config_class = ShieldGemma2Config
def __init__(self, config: ShieldGemma2Config):
super().__init__(config=config)
self.yes_token_index = getattr(config, "yes_token_index", 10_784)
self.no_token_index = getattr(config, "no_token_index", 3771)
self.model = AutoModelForImageTextToText.from_config(config=config)
def get_input_embeddings(self):
return self.model.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.model.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.model.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.model.language_model.set_decoder(decoder)
def get_decoder(self):
return self.model.language_model.get_decoder()
def tie_weights(self):
return self.model.language_model.tie_weights()
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(SHIELDGEMMA2_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=ShieldGemma2ImageClassifierOutputWithNoAttention, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> ShieldGemma2ImageClassifierOutputWithNoAttention:
"""Predicts the binary probability that the image violates the specified policy.
Returns:
"""
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
token_type_ids=token_type_ids,
cache_position=cache_position,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs.logits
selected_logits = logits[:, -1, [self.yes_token_index, self.no_token_index]]
probabilities = torch.softmax(selected_logits, dim=-1)
return ShieldGemma2ImageClassifierOutputWithNoAttention(
logits=selected_logits,
probabilities=probabilities,
)
__all__ = [
"ShieldGemma2ForImageClassification",
]
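
The `forward` method above keeps only the `Yes` and `No` logits at the final sequence position and normalizes them with a softmax. Here is a dependency-free sketch of that step for a single sequence. The five-token vocabulary and the indices 1 and 3 are made up for illustration; the real model indexes a much larger vocabulary with `yes_token_index` and `no_token_index`.

```python
import math


def yes_no_probabilities(last_position_logits, yes_index, no_index):
    """Softmax restricted to the Yes and No logits of one sequence."""
    selected = [last_position_logits[yes_index], last_position_logits[no_index]]
    peak = max(selected)  # subtract the max for numerical stability
    exps = [math.exp(value - peak) for value in selected]
    total = sum(exps)
    return [value / total for value in exps]


# Toy vocabulary of 5 tokens; indices 1 and 3 play the roles of Yes and No.
toy_logits = [0.1, 2.0, -1.0, 0.0, 0.5]
probabilities = yes_no_probabilities(toy_logits, yes_index=1, no_index=3)
violation_probability = probabilities[0]  # index 0 is Yes, i.e. "violates"
print(probabilities)
```

Because the softmax runs over two logits only, the two probabilities always sum to 1 and say nothing about how likely the model was to emit some other token.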

@@ -0,0 +1,195 @@
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Mapping, Sequence
from typing import Optional
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import Unpack
from ...utils import logging
from ..gemma3.processing_gemma3 import Gemma3Processor, Gemma3ProcessorKwargs
logger = logging.get_logger(__name__)
DEFAULT_SHIELDGEMMA2_POLICIES: Mapping[str, str] = {
"dangerous": (
"No Dangerous Content: The image shall not contain content that"
" facilitates or encourages activities that could cause real-world harm"
" (e.g., building firearms and explosive devices, promotion of"
" terrorism, instructions for suicide)."
),
"sexual": (
"No Sexually Explicit content: The image shall not contain content that"
" depicts explicit or graphic sexual acts (e.g., pornography, erotic"
" nudity, depictions of rape or sexual assault)."
),
"violence": (
"No Violence/Gore content: The image shall not contain content that"
" depicts shocking, sensational, or gratuitous violence (e.g.,"
" excessive blood and gore, gratuitous violence against animals,"
" extreme injury or moment of death)."
),
}
class ShieldGemma2ProcessorKwargs(Gemma3ProcessorKwargs, total=False):
policies: Optional[Sequence[str]]
custom_policies: Optional[Mapping[str, str]]
_defaults = {
"text_kwargs": {
"padding": True,
},
"images_kwargs": {
"do_pan_and_scan": False,
},
}
class ShieldGemma2Processor(Gemma3Processor):
def __init__(
self, image_processor, tokenizer, chat_template=None, image_seq_length=256, policy_definitions=None, **kwargs
):
"""A processor for the ShieldGemma 2 model.
Args:
image_processor: The image processor to use, typically a `Gemma3ImageProcessorFast` instance.
tokenizer: The tokenizer to use, typically a `GemmaTokenizerFast` instance.
chat_template: The chat template to use with this processor. Typically, this is unset as the processor
configuration on Hugging Face Hub includes this value already.
image_seq_length: The number of soft tokens per image. Typically, this is unset as the processor
configuration on Hugging Face Hub includes this value already.
policy_definitions: A mapping from policy name to its description in text used as the default policies to
classify images against. The policy descriptions are included in the text of the prompts generated by
this processor. Typically, this is unset as the processor configuration on Hugging Face Hub includes
the base policies ShieldGemma was trained on.
"""
super().__init__(image_processor, tokenizer, chat_template, image_seq_length, **kwargs)
if policy_definitions is None:
self.policy_definitions = DEFAULT_SHIELDGEMMA2_POLICIES
else:
self.policy_definitions = policy_definitions
def __call__(
self,
images: ImageInput = None,
text=None,
videos=None,
audio=None,
**kwargs: Unpack[ShieldGemma2ProcessorKwargs],
) -> BatchFeature:
"""Generates a batch of inputs from the provided images.
ShieldGemma was trained to classify image content for policy compliance using a specific prompt construction.
This processor generates a batch of such prompts from the provided images by:
1. Creating a list of conversations, one for each `<image, policy>` pair;
2. Converting these conversations to text using `self.apply_chat_template()`; and
3. Encoding the conversations and images using the same techniques as `Gemma3Processor`.
Args:
images: A single image or a list of images to include in the batch.
text: Not supported.
videos: Not supported.
audio: Not supported.
kwargs: An optional dictionary of keyword arguments to configure the
processor. Possible values include:
* `custom_policies`: Additional policy definitions that augment the `self.policy_definitions` passed
into the constructor. Note that `custom_policies` that share a key with `self.policy_definitions`
will override the policy description.
* `policies`: (Optional) a list of keys in the joint `self.policy_definitions | custom_policies`
dictionary of specific interest for the provided images. If empty or None, prompts will be
generated for every key in the joint dictionary.
Returns:
A `BatchFeature` containing `input_ids`, `pixel_values`, etc. where each Tensor is of shape
`(len(images) * len(policies), )`, and the order within the batch will be
img1_policy1, ... img1_policyN, ... imgM_policyN.
"""
del text, videos, audio
if not images:
raise ValueError("ShieldGemma 2 needs images to classify")
elif not isinstance(images, Sequence):
images = [images]
if not self.chat_template:
raise ValueError("ShieldGemma 2 requires the use of a specific chat template")
# Disable pan and scan
images_kwargs = kwargs.setdefault("images_kwargs", {})
if images_kwargs.get("do_pan_and_scan") is True:
logger.warning_once("ShieldGemma2 does not support pan and scan.")
images_kwargs["do_pan_and_scan"] = False
# Enable padding on the batch during tokenization
text_kwargs = kwargs.setdefault("text_kwargs", {})
if "padding" not in text_kwargs:
text_kwargs["padding"] = kwargs.pop("padding", True)
text_kwargs["padding_side"] = kwargs.pop("padding_side", "left")
policy_definitions: Mapping[str, str] = {
**self.policy_definitions,
**kwargs.get("custom_policies", {}),
}
if not (policies := kwargs.get("policies")):
policies = list(policy_definitions.keys())
# TODO(ryanmullins): Support images from PIL or URLs.
messages = []
expanded_images = []
for img in images:
for policy in policies:
messages.append(
[
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": policy_definitions[policy]},
],
}
]
)
expanded_images.append([img])
text = self.apply_chat_template(messages, tokenize=False)
return super().__call__(images=expanded_images, text=text, **kwargs)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
__all__ = ["ShieldGemma2Processor"]
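
The processor emits one prompt per `<image, policy>` pair, iterating over images in the outer loop and policies in the inner loop. The sketch below reproduces that ordering with plain strings and shows one way to regroup a flat list of per-prompt scores by image. The helper names `expand_pairs` and `group_by_image`, the image names, and the scores are all hypothetical.

```python
def expand_pairs(image_names, policy_names):
    """List (image, policy) pairs in the order the processor builds prompts."""
    return [(image, policy) for image in image_names for policy in policy_names]


def group_by_image(flat_scores, image_names, policy_names):
    """Regroup a flat list of per-prompt scores into {image: {policy: score}}."""
    stride = len(policy_names)
    return {
        image: dict(zip(policy_names, flat_scores[i * stride : (i + 1) * stride]))
        for i, image in enumerate(image_names)
    }


images = ["img1", "img2"]  # placeholder names, not real images
policies = ["dangerous", "sexual", "violence"]
pairs = expand_pairs(images, policies)
scores = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]  # made-up Yes probabilities
grouped = group_by_image(scores, images, policies)
print(grouped["img2"]["violence"])
```

With real model outputs, the same regrouping could be expressed as a reshape of the probabilities to `(len(images), len(policies), 2)`, under the assumption that the batch order is unchanged between processor and model.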

@@ -8870,6 +8870,13 @@ class SEWDPreTrainedModel(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class ShieldGemma2ForImageClassification(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class SiglipForImageClassification(metaclass=DummyObject):
    _backends = ["torch"]

@@ -0,0 +1,61 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch ShieldGemma2 model."""
import unittest
from io import BytesIO
import requests
from PIL import Image
from transformers import is_torch_available
from transformers.testing_utils import (
cleanup,
require_torch_gpu,
slow,
torch_device,
)
if is_torch_available():
import torch
from transformers import ShieldGemma2ForImageClassification, ShieldGemma2Processor
@slow
@require_torch_gpu
# @require_read_token
class ShieldGemma2IntegrationTest(unittest.TestCase):
def tearDown(self):
cleanup(torch_device, gc_collect=True)
def test_model(self):
model_id = "google/shieldgemma-2-4b-it"
processor = ShieldGemma2Processor.from_pretrained(model_id, padding_side="left")
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
response = requests.get(url)
image = Image.open(BytesIO(response.content))
model = ShieldGemma2ForImageClassification.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
).to(torch_device)
inputs = processor(images=[image]).to(torch_device)
output = model(**inputs)
self.assertEqual(len(output.probabilities), 3)
for element in output.probabilities:
self.assertEqual(len(element), 2)
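
The test loads the processor with `padding_side="left"`, and the model reads its logits at position `-1`. A plausible reason for combining the two is that left padding keeps every prompt's final real token in the last slot of a batch. The sketch below illustrates this with plain lists; the `pad_batch` helper, the token IDs, and the padding ID of 0 are assumptions for illustration.

```python
PAD = 0  # assumed padding ID, for illustration only


def pad_batch(sequences, side):
    """Pad every sequence to the longest length on the given side."""
    width = max(len(seq) for seq in sequences)
    padded = []
    for seq in sequences:
        fill = [PAD] * (width - len(seq))
        padded.append(fill + seq if side == "left" else seq + fill)
    return padded


batch = [[5, 6, 7], [8, 9]]  # two prompts of different lengths
left_last = [row[-1] for row in pad_batch(batch, "left")]
right_last = [row[-1] for row in pad_batch(batch, "right")]
print(left_last, right_last)
```

With right padding, the shorter prompt ends in a padding token, so reading position `-1` would score the wrong slot for that row.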

@@ -0,0 +1,220 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import tempfile
import unittest
from collections.abc import Mapping
from parameterized import parameterized
from transformers import GemmaTokenizer, ShieldGemma2Processor
from transformers.testing_utils import get_tests_dir, require_vision
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from transformers import Gemma3ImageProcessor
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
# Copied from _CHAT_TEMPLATE in src/transformers/models/shieldgemma2/convert_shieldgemma2_weights_orbax_to_hf.py
_CHAT_TEMPLATE = """{{ bos_token }}
{%- for message in messages -%}
{%- if message['content'] is iterable -%}
{%- for item in message['content'] -%}
{{ '<start_of_turn>user\n' }}
{%- if item['type'] == 'image' -%}
{{ '<start_of_image>' }}
{%- elif item['type'] == 'text' -%}
{{ 'You are a policy expert trying to help determine whether a generated or real image is in' }}
{{ ' violation of the safety policy below. Your decision must very clearly follow the specific'}}
{{ ' language used in the safety policy and should not include any interpretation outside of the' }}
{{ ' safety policy.\n\n' }}
{{ '<BEGIN_SAFETY_POLICY>\n' }}
{{ item['text'] + '\n' }}
{{ '</END_SAFETY_POLICY>\n\n' }}
{{ 'Does the image violate the above policy? Your Answer must start with \\'Yes\\' or \\'No\\'.' }}
{{ '<end_of_turn>\n' }}
{%- endif -%}
{%- endfor -%}
{{'<start_of_turn>model\n'}}
{%- else -%}
{{ raise_exception("Conversation messages must contain iterable content containing images and policy definitions in text.") }}
{%- endif -%}
{%- endfor -%}
"""
# Simplified from _SHIELDGEMMA2_POLICIES in src/transformers/models/shieldgemma2/convert_shieldgemma2_weights_orbax_to_hf.py
_SHIELDGEMMA2_POLICIES: Mapping[str, str] = {
"dangerous": "Test policy related to dangerous content.",
"sexual": "Test policy related to sexually explicit content.",
"violence": "Test policy related to violent content.",
}
@require_vision
class ShieldGemma2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = ShieldGemma2Processor
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = Gemma3ImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
extra_special_tokens = {
"image_token": "<image_soft_token>",
"boi_token": "<start_of_image>",
"eoi_token": "<end_of_image>",
}
tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True, extra_special_tokens=extra_special_tokens)
processor_kwargs = self.prepare_processor_dict()
processor = ShieldGemma2Processor(image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs)
processor.save_pretrained(self.tmpdirname)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def prepare_processor_dict(self):
return {
"chat_template": _CHAT_TEMPLATE,
"policy_definitions": _SHIELDGEMMA2_POLICIES,
}
def test_policy_definitions_saved_in_config(self):
processor_config_path = os.path.join(self.tmpdirname, "processor_config.json")
with open(processor_config_path, "rb") as processor_config_file:
json_dict = json.load(processor_config_file)
self.assertIsInstance(json_dict, dict)
self.assertIn("policy_definitions", json_dict)
self.assertEqual(len(json_dict["policy_definitions"]), 3)
@parameterized.expand(
[
("all_policies", None, 3),
("selected_policies", ["dangerous", "violence"], 2),
("single_policy", ["sexual"], 1),
]
)
def test_with_default_policies(self, name, policies, expected_batch_size):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
images = self.prepare_image_inputs()
processed_inputs = processor(images=images, policies=policies)
self.assertEqual(len(processed_inputs[self.text_input_name]), expected_batch_size)
self.assertEqual(len(processed_inputs[self.images_input_name]), expected_batch_size)
@parameterized.expand(
[
("all_policies", None, 6),
("selected_policies_from_both", ["cbrne", "dangerous", "specialized_advice", "violence"], 4),
("selected_policies_from_custom", ["cbrne", "specialized_advice"], 2),
("selected_policies_from_default", ["dangerous", "violence"], 2),
("single_policy_from_custom", ["ip"], 1),
("single_policy_from_default", ["sexual"], 1),
]
)
def test_with_custom_policies(self, name, policies, expected_batch_size):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
# Test policies adapted from https://ailuminate.mlcommons.org/benchmarks/ hazard categories
custom_policies = {
"cbrne": "Test policy related to indiscriminate weapons.",
"ip": "Test policy related to intellectual property.",
"specialized_advice": "Test policy related to specialized advice.",
}
images = self.prepare_image_inputs()
processed_inputs = processor(images=images, custom_policies=custom_policies, policies=policies)
self.assertEqual(len(processed_inputs[self.text_input_name]), expected_batch_size)
self.assertEqual(len(processed_inputs[self.images_input_name]), expected_batch_size)
def test_with_multiple_images(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
images = self.prepare_image_inputs(batch_size=2)
processed_inputs = processor(images=images)
self.assertEqual(len(processed_inputs[self.text_input_name]), 6)
self.assertEqual(len(processed_inputs[self.images_input_name]), 6)
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_chat_template_accepts_processing_kwargs(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_chat_template_batched(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_chat_template_dict_torch(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_chat_template_single(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("Parent test needs to be adapted for ShieldGemma 2.")
def test_unstructured_kwargs_batched(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("Parent test needs to be adapted for ShieldGemma 2.")
def test_unstructured_kwargs(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("Parent test needs to be adapted for ShieldGemma 2.")
def test_tokenizer_defaults_preserved_by_kwargs(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("Parent test needs to be adapted for ShieldGemma 2.")
def test_structured_kwargs_nested_from_dict(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("Parent test needs to be adapted for ShieldGemma 2.")
def test_structured_kwargs_nested(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("Parent test needs to be adapted for ShieldGemma 2.")
def test_kwargs_overrides_default_tokenizer_kwargs(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("Parent test needs to be adapted for ShieldGemma 2.")
def test_kwargs_overrides_default_image_processor_kwargs(self):
pass
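
The batch sizes expected by the parameterized tests above follow from how default and custom policies are merged and then filtered. Here is a small sketch of that resolution, following the documented behaviour that an empty or missing selection means every policy. `resolve_policies` is a hypothetical helper, not part of the processor, and the policy texts are placeholders.

```python
DEFAULTS = {
    "dangerous": "default dangerous text",
    "sexual": "default sexual text",
    "violence": "default violence text",
}
CUSTOM = {
    "cbrne": "custom cbrne text",
    "ip": "custom ip text",
    "specialized_advice": "custom advice text",
}


def resolve_policies(defaults, custom=None, selected=None):
    """Merge custom policies over the defaults, then keep the selected keys."""
    merged = {**defaults, **(custom or {})}  # custom entries win on key clashes
    keys = list(selected) if selected else list(merged)
    return {key: merged[key] for key in keys}


all_policies = resolve_policies(DEFAULTS, CUSTOM)
subset = resolve_policies(DEFAULTS, CUSTOM, ["cbrne", "dangerous"])
overridden = resolve_policies(DEFAULTS, {"sexual": "stricter text"})
print(len(all_policies), len(subset), overridden["sexual"])
```

For one image, the number of resolved policies equals the batch size the tests assert; for several images it is multiplied by the image count.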

@@ -228,6 +228,14 @@ SPECIAL_CASES_TO_ALLOW = {
    "GPTNeoXConfig": ["rotary_emb_base"],
    "Gemma3Config": ["boi_token_index", "eoi_token_index"],
    "Gemma3TextConfig": ["cache_implementation", "tie_word_embeddings"],
    "ShieldGemma2Config": [
        "boi_token_index",
        "eoi_token_index",
        "initializer_range",
        "mm_tokens_per_image",
        "text_config",
        "vision_config",
    ],
}

@@ -167,6 +167,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
    "models/vision_text_dual_encoder/test_modeling_flax_vision_text_dual_encoder.py",
    "models/decision_transformer/test_modeling_decision_transformer.py",
    "models/bark/test_modeling_bark.py",
    "models/shieldgemma2/test_modeling_shieldgemma2.py",
]
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and