# Gemma3

## Overview
The Gemma 3 model was proposed in the Gemma 3 Technical Report by Google. It is a vision-language model composed of a SigLIP vision encoder and a Gemma 2 language decoder, linked by a multimodal linear projection. It cuts an image into a fixed number of tokens, in the same way as SigLIP, as long as the image does not exceed a certain aspect ratio. For images that exceed the given aspect ratio, it crops the image into multiple smaller patches and concatenates them with the base image embedding. One particularity is that the model uses bidirectional attention on all the image tokens. In addition, the model interleaves sliding window local attention with full causal attention in the language backbone, where every sixth layer is a full causal attention layer.
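To make the interleaving pattern concrete, the sketch below labels each decoder layer as sliding window (local) or full causal (global) attention, assuming the pattern simply repeats five local layers followed by one global layer. `layer_attention_types` and its `pattern=6` default are hypothetical names for illustration, not a Transformers API; the real layer count comes from the checkpoint's text config.

```python
def layer_attention_types(num_layers: int, pattern: int = 6) -> list[str]:
    """Label each decoder layer "sliding" (local) or "full" (global causal).

    Assumes every `pattern`-th layer, counting from 1, uses full causal
    attention and all other layers use sliding window attention.
    """
    types = []
    for layer_idx in range(num_layers):
        # With pattern=6, 0-indexed layers 5, 11, 17, ... are the full attention layers
        if (layer_idx + 1) % pattern == 0:
            types.append("full")
        else:
            types.append("sliding")
    return types


if __name__ == "__main__":
    types = layer_attention_types(12)
    print(types)
    # Prints the indices of the full attention layers: [5, 11]
    print([i for i, t in enumerate(types) if t == "full"])
```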
This model was contributed by Ryan Mullins, Raushan Turganbay, Arthur Zucker, and Pedro Cuenca.
### Usage tips

- For image+text and image-only inputs, use `Gemma3ForConditionalGeneration`.
- For text-only inputs, use `Gemma3ForCausalLM` for generation to avoid loading the vision tower.
- Each sample can contain multiple images, and the number of images can vary between samples. However, make sure to pass correctly batched images to the processor: one list of one or more images per sample.
- The text passed to the processor should have a `<start_of_image>` token wherever an image should be inserted.
- The processor has its own `apply_chat_template` method to convert chat messages to model inputs. See the examples below for more details on how to use it.
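The batching and placeholder rules above can be checked before calling the processor. This is a minimal pre-flight sketch, assuming you build prompt strings by hand (rather than through `apply_chat_template`) and group images as one list per sample; `check_batch` is a hypothetical helper, not part of Transformers, and plain strings stand in for PIL images.

```python
IMAGE_TOKEN = "<start_of_image>"


def check_batch(texts: list[str], images: list[list[object]]) -> None:
    """Raise ValueError if texts and images are not batched consistently.

    Each sample is one text plus a list of one or more images, and each
    text must contain one IMAGE_TOKEN per image in its list.
    """
    if len(texts) != len(images):
        raise ValueError(f"{len(texts)} texts but {len(images)} image lists")
    for i, (text, sample_images) in enumerate(zip(texts, images)):
        # Every sample's images must be wrapped in their own non-empty list
        if not isinstance(sample_images, list) or not sample_images:
            raise ValueError(f"sample {i}: images must be a non-empty list")
        n_tokens = text.count(IMAGE_TOKEN)
        if n_tokens != len(sample_images):
            raise ValueError(
                f"sample {i}: {n_tokens} {IMAGE_TOKEN} tokens but {len(sample_images)} images"
            )


if __name__ == "__main__":
    texts = [
        f"{IMAGE_TOKEN} Describe this image.",
        f"{IMAGE_TOKEN} {IMAGE_TOKEN} Which image is cuter?",
    ]
    # Placeholder strings stand in for PIL images; the second sample has two images
    images = [["img_a"], ["img_b", "img_c"]]
    check_batch(texts, images)
    print("batch is consistent")
```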
### Image cropping for high resolution images

The model supports cropping images into smaller patches when the image aspect ratio exceeds a certain value. By default, images are not cropped and only the base image is forwarded to the model. Set `do_pan_and_scan=True` to obtain several crops per image along with the base image, which improves quality on DocVQA and similar tasks that require higher resolution images.

Pan and scan is an inference-time optimization for handling images with skewed aspect ratios. When enabled, it improves performance on tasks related to document understanding, infographics, OCR, and so on.
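The cropping rule can be approximated in plain Python. This is a simplified sketch under stated assumptions, not the image processor's code: `pan_and_scan_boxes` is a hypothetical helper, and its thresholds (crop only above an aspect ratio of 1.2, crops of at least 256 pixels, between 2 and 4 crops along the long side) are assumed values. Check the [`Gemma3ImageProcessor`] reference for the options and defaults it actually exposes.

```python
import math


def pan_and_scan_boxes(
    width: int,
    height: int,
    min_ratio: float = 1.2,
    min_crop_size: int = 256,
    max_num_crops: int = 4,
) -> list[tuple[int, int, int, int]]:
    """Return crop boxes as (left, top, right, bottom), or [] if no cropping applies.

    Threshold defaults are illustrative assumptions.
    """
    long_side, short_side = max(width, height), min(width, height)
    # Only crop images whose aspect ratio is skewed enough
    if long_side / short_side < min_ratio:
        return []
    # Crops along the long side: the aspect ratio rounded half up, capped so
    # each crop spans at least min_crop_size, then clamped to [2, max_num_crops]
    num_crops = int(math.floor(long_side / short_side + 0.5))
    num_crops = min(long_side // min_crop_size, num_crops)
    num_crops = min(max_num_crops, max(2, num_crops))
    num_w, num_h = (num_crops, 1) if width >= height else (1, num_crops)
    crop_w = math.ceil(width / num_w)
    crop_h = math.ceil(height / num_h)
    # Skip cropping entirely if the resulting crops would be too small
    if min(crop_w, crop_h) < min_crop_size:
        return []
    boxes = []
    for row in range(num_h):
        for col in range(num_w):
            left, top = col * crop_w, row * crop_h
            boxes.append((left, top, min(left + crop_w, width), min(top + crop_h, height)))
    return boxes


if __name__ == "__main__":
    print(pan_and_scan_boxes(896, 896))   # square image: no crops
    print(pan_and_scan_boxes(1800, 600))  # wide image: three side-by-side crops
    print(pan_and_scan_boxes(500, 1000))  # tall image: two stacked crops
```

The base image is still forwarded in every case; the crops are extra views, which is why an empty list simply means "base image only".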
```python
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")

url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
messages = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."}
        ]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "url": url},
            {"type": "text", "text": "What is shown in this image?"},
        ]
    },
]

# do_pan_and_scan=True adds crops of the image alongside the base image
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
    do_pan_and_scan=True,
).to(model.device)
```
## Usage Example

### Single-image Inference
```python
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id, padding_side="left")

url = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
messages = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."}
        ]
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "url": url},
            {"type": "text", "text": "What is shown in this image?"},
        ]
    },
]

inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)

output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
print(processor.decode(output[0], skip_special_tokens=True))
```
Use the `AttentionMaskVisualizer` to better understand what tokens the model can and cannot attend to.

```python
from transformers.utils.attention_visualizer import AttentionMaskVisualizer

visualizer = AttentionMaskVisualizer("google/gemma-3-4b-it")
visualizer("<img>What is shown in this image?")
```
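The pattern the visualizer draws can be approximated with a small sketch: causal attention everywhere, plus bidirectional attention among image tokens. This is a simplification, assuming bidirectional attention is limited to tokens of the same image (identified here by an integer id) and ignoring the sliding window of the local layers; `gemma3_style_mask` and `render` are hypothetical helpers, not the library's masking code.

```python
from typing import Optional


def gemma3_style_mask(image_ids: list[Optional[int]]) -> list[list[bool]]:
    """Return mask[q][k] == True when query token q may attend to key token k.

    image_ids[i] is None for a text token, or an integer identifying the
    image that token i belongs to.
    """
    n = len(image_ids)
    mask = [[False] * n for _ in range(n)]
    for q in range(n):
        for k in range(n):
            # Standard causal rule: attend to the current and earlier positions
            causal = k <= q
            # Image tokens also see later tokens of the same image
            same_image = image_ids[q] is not None and image_ids[q] == image_ids[k]
            mask[q][k] = causal or same_image
    return mask


def render(mask: list[list[bool]]) -> str:
    """Draw the mask row by row: 'x' = can attend, '.' = masked."""
    return "\n".join("".join("x" if v else "." for v in row) for row in mask)


if __name__ == "__main__":
    # text token, three tokens of image 0, then another text token
    print(render(gemma3_style_mask([None, 0, 0, 0, None])))
```

In the rendered grid, the rows for the image tokens extend past the diagonal to the end of the image block, while text rows stay strictly lower-triangular.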
## Notes

- Use [`Gemma3ForConditionalGeneration`] for image-and-text and image-only inputs.
- Gemma 3 supports multiple input images, but make sure the images are correctly batched before passing them to the processor. Each sample's images should be given as a list of one or more images.

    ```python
    url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
    url_cat = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"

    messages = [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": "You are a helpful assistant."}
            ]
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "url": url_cow},
                {"type": "image", "url": url_cat},
                {"type": "text", "text": "Which image is cuter?"},
            ]
        },
    ]
    ```

- Text passed to the processor should have a `<start_of_image>` token wherever an image should be inserted.
- The processor has its own [`~ProcessorMixin.apply_chat_template`] method to convert chat messages to model inputs.
- By default, images aren't cropped and only the base image is forwarded to the model. In high resolution images or images with non-square aspect ratios, artifacts can result because the vision encoder uses a fixed resolution of 896x896. To prevent these artifacts and improve performance during inference, set `do_pan_and_scan=True` to crop the image into multiple smaller patches and concatenate them with the base image embedding. You can disable pan and scan for faster inference.

    ```diff
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        add_generation_prompt=True,
    +   do_pan_and_scan=True,
    ).to(model.device)
    ```

- The Gemma 3 1B checkpoint was trained in text-only mode, so use [`AutoModelForCausalLM`] for it instead.

    ```python
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-pt")
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-3-1b-pt",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="sdpa",
    )

    # The tokenizer returns a dict-like BatchEncoding (input_ids and attention_mask)
    inputs = tokenizer("Plants create energy through a process known as", return_tensors="pt").to(model.device)

    outputs = model.generate(**inputs, max_new_tokens=100)
    text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    print(text)
    ```
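The speed cost of pan and scan comes mostly from a longer input sequence, since each crop is encoded like an extra image. The sketch below estimates the image tokens per sample, assuming every image (base or crop) is encoded into a fixed 256 soft tokens; that figure is an assumption here, so compare it with `mm_tokens_per_image` in your checkpoint's [`Gemma3Config`]. `image_token_budget` is a hypothetical helper, and it ignores any extra marker tokens placed around each image.

```python
def image_token_budget(crops_per_image: list[int], tokens_per_image: int = 256) -> int:
    """Estimate the image tokens a sample adds to the sequence.

    crops_per_image[i] is the number of pan-and-scan crops for image i
    (0 when cropping is disabled or not triggered). Each crop is assumed
    to cost the same number of tokens as the base image.
    """
    return sum((1 + crops) * tokens_per_image for crops in crops_per_image)


if __name__ == "__main__":
    # Two images without pan and scan
    print(image_token_budget([0, 0]))  # 512
    # The same two images when pan and scan produces 3 and 2 crops
    print(image_token_budget([3, 2]))  # 1792
```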
## Gemma3ImageProcessor
[[autodoc]] Gemma3ImageProcessor
## Gemma3ImageProcessorFast
[[autodoc]] Gemma3ImageProcessorFast
## Gemma3Processor
[[autodoc]] Gemma3Processor
## Gemma3TextConfig
[[autodoc]] Gemma3TextConfig
## Gemma3Config
[[autodoc]] Gemma3Config
## Gemma3TextModel
[[autodoc]] Gemma3TextModel
- forward
## Gemma3ForCausalLM
[[autodoc]] Gemma3ForCausalLM
- forward
## Gemma3ForConditionalGeneration
[[autodoc]] Gemma3ForConditionalGeneration
- forward