* Gemma 3n

* initial commit of Gemma 3n scaffold

* Fixing param pass through on Gemm3p5RMSNorm

* Adds Einsum layer to Gemma 3n

* Updating EinsumLayer API

* Undoing erroneous force push

* Reverting RMSNorm to with_scale by default

* Adds LAuReL to Gemma 3n

* Adds AltUp to Gemma 3n

* Adding Gemma3p5 overall and text config with vision and audio config placeholders (#3)

* Adding gemma3p5 text configs

* Adding audio config placeholders

* Adding a placeholder for vision configs

* Updating MobileNetVisionConfig, inheriting TimmWrapperConfig

* Updating text configs

* Update src/transformers/models/gemma3p5/modular_gemma3p5.py

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Removing altup configs to accept the suggested configs

* Update src/transformers/models/gemma3p5/modular_gemma3p5.py

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Updating altup config

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Addressing review comments and updating text configs

* Adding a config for activation sparsity

* Updating configs to pass through options to super class init and adjust some name prefixes

* Updating laurel and altup with corrected config values

* Normalizing sub_config initializers

---------

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Updating MLP with activation sparsity (#2)

* Updating DecoderBlock for Gemma 3n (#3)

* Initial Gemm3nTextModel (#4)

NOTE: This implementation WILL CHANGE in the coming weeks, however, changes will be strictly additive and this will remain a suitable baseline for downstream implementations to reference.

* Adding KV Cache Sharing

* Adds Einsum layer to Gemma 3n

* Updating EinsumLayer API

* Refactored kv cache sharing in attention

* Adding KVStore for cache sharing

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update src/transformers/cache_utils.py

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Undoing erroneous force push

* Reverting RMSNorm to with_scale by default

* Adds LAuReL to Gemma 3n

* Updating KV Cache Sharing implementation

* Updating the q and k norm definitions in the attention module

* Fixing name error for q,k,v RMS norm to use the right 3n module

* Updating MLP with activation sparsity

* Updating DecoderBlock for Gemma 3.5

* Updating kv cache sharing implementation with the use of a cache buffer and refactoring some lines of code

* Isolating KV Cache logic to relevant components

* Fixing logic error in Gemma3nAttention.forward

* Refactoring caching contributions and fixing kv_store initialization

* Simplifying Configs

* Remove errant self from super init call

* Bug fix in the Attention module - changing self.head_dim to config.head_dim

* Bug fixes in the LaurelBlock and RMS Norm super init call

* removing redundant code from a merge

* Adding per_layer_inputs to TextModel

* Adding preprocess embeddings with altup

* Adds per-layer-to-single output and a host of TODOs

* Integrating altup predict with the model workflow and other minor bug fixes

* Using nn.Embedding temporarily for text model

* It goes forward

* Minor refactor of attention sparsity and RoPE initialization

* Fixing duplicate rope_scaling param bug when loading from pretrained

---------

Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>
Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>

* Normalizing on altup_num_inputs config option

* regenerating modeling file after syncing to HEAD

* Use torch.std(..., unbiased=False) for activation sparsity (#8)

* Refactoring to a single QVK Norm (#13)

* AltUp: support scale_corrected_output (#14)

* Converts einsums to nn.Linear (#7)

* Converts einsums to nn.Linear

* Removing unused variables

* Aligning SharedKVCache with HybridCache (#11)

* Alinging SharedKVStore with HybridCache

* Remove KVStore. Refactor apply_rotary_pos_emb for sharing

* Addressing review comments

* Supporting split modality embeddings in Gemma3n (#10)

* Adding the Embedder class

* Update modular

Co-authored-by: Ryan Mullins <ryan@ryanmullins.org>

* Update modular

Co-authored-by: Ryan Mullins <ryan@ryanmullins.org>

* Update modular

Co-authored-by: Ryan Mullins <ryan@ryanmullins.org>

* Update modular

Co-authored-by: Ryan Mullins <ryan@ryanmullins.org>

* Update modular

Co-authored-by: Ryan Mullins <ryan@ryanmullins.org>

* Update modular

Co-authored-by: Ryan Mullins <ryan@ryanmullins.org>

* Addressing review comments, adding audio embedding layers, integrating embedder with the remaining architecture, adding a forward method for conditional generation

* Apply suggestions from code review

Co-authored-by: Ryan Mullins <ryan@ryanmullins.org>

* Update modular

Co-authored-by: Ryan Mullins <ryan@ryanmullins.org>

* Addressing review comments, prop drilling audio and vision configs to the text config

* Removing TODO's that have been addressed

* Simplify Embedder init and add audio embeddings

* Embeddings refactor. Adds Gemma3nAudioEmbedder and Gemma3nVisionEmbedder

* Refactoring vision and audio embeddings into ConditionalGeneration model

---------

Co-authored-by: Ryan Mullins <ryan@ryanmullins.org>
Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Updating attention mask for Gemma 3.5 (#15)

* xxx_token_index to xxx_token_id

* remvoing deprecated last_cache_position

* Removing references to SigLIP

* Always init per-layer inputs

* Using torch.finfo().min for epsilon_tensor

* Gemma3nDecoderLayer inherits from Gemma3DecoderLayer. Remove gating lambdas

* fix modular GEMMA3N_INPUTS_DOCSTRING

* Gemma3nAttention inherits from Gemma3Attention

* Modular inheritance fixes

* CausalLM conversion script for 4B model (#16)

* Add Gemma3n Audio Encoder (#6)

* initial commit of Gemma 3.5 scaffold

* Fixing param pass through on Gemm3nRMSNorm

* Adds Einsum layer to Gemma 3.5

* Updating EinsumLayer API

* Undoing erroneous force push

* Reverting RMSNorm to with_scale by default

* Adds LAuReL to Gemma 3n

* Adds AltUp to Gemma 3n

* Adding Gemma3n overall and text config with vision and audio config placeholders (#3)

* Adding gemma3n text configs

* Adding audio config placeholders

* Adding a placeholder for vision configs

* Updating MobileNetVisionConfig, inheriting TimmWrapperConfig

* Updating text configs

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Removing altup configs to accept the suggested configs

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Updating altup config

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Addressing review comments and updating text configs

* Adding a config for activation sparsity

* Updating configs to pass through options to super class init and adjust some name prefixes

* Updating laurel and altup with corrected config values

* Normalizing sub_config initializers

---------

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Updating MLP with activation sparsity (#2)

* Updating DecoderBlock for Gemma 3.5 (#3)

* Initial Gemm3nTextModel (#4)

NOTE: This implementation WILL CHANGE in the coming weeks, however, changes will be strictly additive and this will remain a suitable baseline for downstream implementations to reference.

* Adding KV Cache Sharing

* Adds Einsum layer to Gemma 3.5

* Updating EinsumLayer API

* Refactored kv cache sharing in attention

* Adding KVStore for cache sharing

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update src/transformers/cache_utils.py

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Undoing erroneous force push

* Reverting RMSNorm to with_scale by default

* Adds LAuReL to Gemma 3n

* Updating KV Cache Sharing implementation

* Updating the q and k norm definitions in the attention module

* Fixing name error for q,k,v RMS norm to use the right Gemma 3n module

* Updating MLP with activation sparsity

* Updating DecoderBlock for Gemma 3.5

* Updating kv cache sharing implementation with the use of a cache buffer and refactoring some lines of code

* Isolating KV Cache logic to relevant components

* Fixing logic error in Gemma3nAttention.forward

* Refactoring caching contributions and fixing kv_store initialization

* Simplifying Configs

* Remove errant self from super init call

* Bug fix in the Attention module - changing self.head_dim to config.head_dim

* Bug fixes in the LaurelBlock and RMS Norm super init call

* removing redundant code from a merge

* Adding per_layer_inputs to TextModel

* Adding preprocess embeddings with altup

* Adds per-layer-to-single output and a host of TODOs

* Integrating altup predict with the model workflow and other minor bug fixes

* Using nn.Embedding temporarily for text model

* It goes forward

* Minor refactor of attention sparsity and RoPE initialization

* Fixing duplicate rope_scaling param bug when loading from pretrained

---------

Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>
Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>

* Normalizing on altup_num_inputs config option

* Adding audio encoder config

* Adds high-level components for Audio Encoder

* Implement uniform reducer for Audio Encoder

* Adding placeholders for Conformer components in Audio Encoder

* Adding placeholders for SubSampleConvProjection components in Audio Encoder

* Adding SequenceLayer component placeholders

* Implementing Gemma3nAudioEncoder with nn.Sequential

* Implementing Gemma3nAudioSubSampleConvProjection with nn.Sequential

* Implementing Conformer model with SequenceLayers

* Use OrderedDict in nn.Sequential initializers

* Implements sl.Residual in Torch with nn.Sequential and OrderedDict

* Adopting a base SequenceLayer class with default forward() method

* Implementing sl.GatedLinearUnit in Torch

* Implementing sl.Swish in Torch

* Implementing sl.ReLU in Torch

* Implementing sl.Scale in Torch

* Removing sl.Dropout after tree-shaking

* Implementing sl.RMSNorm in Torch with fake shape

* Implementing sl.GroupNorm in Torch

* Implementing sl.Conv2d in Torch

* Implementing sl.Dense in Torch

* Removing sl.Delay layers, which act as pass-throughs

* Connecting shapes to configs in initializers

* Removing sl.Emit

* Implementing sl.ExpandDims in Torch

* Adding sl.GradientClipping to Torch

* Implementing sl.DenseShaped in Torch

* Implementing sl.LDPA in Torch

* Removing unused sl.CombinedQKVProj class

* Fixing erroneous type hint

* Implemnenting sl.DepthwiseConv1D in Torch

* Implementing sl.MaskInvalid in Torch

* Fixes for initialization

* Fixes for saving weights

* Removing einsums per feedback from HF staff

* Removing Sequence Layers idioms from audio encoder

* Fixes for reviewer comments

* CausalLM conversion script for 4B model

* inv_timescales to non-persistent buffer

* Addressing audio encoder Attention feedback

* Addressing Gemma3nAudioSSCPConvBlock feedback

* Addressing Gemma3nAudioConformerAttention feedback

* Addressing padding feedback

* Weights conversion loads audio state dict

* Always use vision_config so saving works

* Token id updates for configs

* Stubs for interleaving audio embs

* Addressing reviewer feedback

---------

Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>
Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>

* Fixing cache access error

* Removing duplicate code from a bad merge

* Gemma 3n Text + Vision Part 1 (#17)

* testing utilities for numerics comparisons

* Corrected einsum to nn.Linear weights conversion

* Inherit scaled word embs from Gemma3 not Bart

* Fixing transposes for collapsed linears

* More transpose fixes

* numpy api fix

* RMSNorm: Explicit kwargs, scale_shift=0.0 when with_scale=True

* Force AltUp  to float32

* Updating debugging script for AudioEncoder debugging

* Support divide_weight_by_sqrt_fan_in from JAX for per-layer inputs

* Correcting attention einsum conversions

* RMSNorm in type of x

* Fixing douplicate laurel norm/gating

* KV sharing using the right previous indices

* Refactor kv shared index computation. Correct frac_shared_layers

* Use num_shared_layers instead of inferring from a fraction

* fixing a bug for logging

* Fix shared data_ptrs in altup inits

* rope: adjust proj -> norm -> rope to preserve computation (#20)

* rope: adjust proj -> norm -> rope to preserve computation

* Removing some breaking language model fluff in ConditionalGeneration

* Consolidate query_states transforms

---------

Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com>
Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Vectorize the loops in AltUp (#19)

* Vectorize the loops in AltUp

* fix typo

* Expanding to support batched inputs

* remove extra debug script

* Fix AltUp.forward

---------

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Add 'scale_shift=0.0, with_scale=True' to the final norm in TextModel

* Convert norm to 1/sqrt (#21)

* Convert norm to 1/sqrt

* Scale shift change per Phil's rec

* Adding default activation sparsity

* Fixing 2B config in weights conversion script

* Fixing RMSNorm parameters - adding scale_shift and with_scale

* Correcting query pre-attention scaling

* Adding query_rescale_scalar to text config

* Adding layer_idx to MLP

* Permafix for input_layernorm

* Use 1/sqrt instead of rsqrt in DecoderLayer

* Fix o_proj conversion

* Conversion script update for vision encoder

* Removing logging for debugging timm model

* Fixing bugs in Gemma3nForConditionalGeneration for text generation

* Generating the modeling_gemma3n.py file

* Removing the addition of an erroneous line in the modeling file

* Adding gemma3n text model to modeling_auto

* Bugfix: Updating the interleaving of inputs_embeds and vision_embeds

* Updating the modeling file with the latest bugfix changes

* Updating models/auto for Gemma 3n

* using AutoTokenizer in forward test

* Adding processing_gemma3n.py

* Gemma 3n configured for AutoModel. Conversion script updated.

* Removing errant merge artifacts

---------

Co-authored-by: Mayank Chaturvedi <imayank@google.com>
Co-authored-by: Douglas Reid <douglas-reid@users.noreply.github.com>
Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>

* Removing errant debugging statements from Gemma 3

* Gemma3n audio model (#18)

* testing utilities for numerics comparisons

* Implement CumulativeGroupNorm and add to SubSampleConvProjection and SSCPConvBlock

* Add audio version of forward script based on RyanMullins' implementation

* Updating to match encoder tests. WIP: config question needs resolving

* Updates to audio classes to enable end-to-end running

* Removing vestigial classes, cleaning up print statements

* Adding SiLU / Swish to audio conformer feed forward block

* Shifted Gemma3p5Audio naming prefix to Gemma3NanoAudio

* Adding outputs to audio test

* Fixes to padding in SSCP and 1D convolution, align RMS Norm with wider model

* Update forward test to load from local weights

* Update conversion to process / output audio layers

* Update __all__ to export audio encoder

* AutoModel registration for Gemma 3n Audio

* Use AutoModel for ConditionalGeneration.audio_tower

* Fixing input_proj_linear transpose

* Fixing Gemma3NanoAudioConformerAttention.post conversion

* Fixing Gemma3NanoAudioSSCPConvBlock.conv weights conversion

* Correcting indentation issue on Gemma3p5RMSNorm

---------

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Text + Vision Part 2 (#23)

* Updates for ConditionalGeneration.get_image_features

* Adding a WIP draft of image_processing_gemma3p5.py

* Update src/transformers/models/gemma3p5/modular_gemma3p5.py

Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>

* Modular conversion after github suggested change

* Text + image gives good results

* Fixing image size preset

* Updating configs for the 2B variant in the conversion script

* Using final generation config in conversion script

---------

Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>
Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>

* Audio Integration (#12)

* initial commit of Gemma 3n scaffold

* Fixing param pass through on Gemm3nRMSNorm

* Adds Einsum layer to Gemma 3n

* Updating EinsumLayer API

* Undoing erroneous force push

* Reverting RMSNorm to with_scale by default

* Adds LAuReL to Gemma 3n

* Adds AltUp to Gemma 3n

* Adding Gemma 3n overall and text config with vision and audio config placeholders (#3)

* Adding Gemma 3n text configs

* Adding audio config placeholders

* Adding a placeholder for vision configs

* Updating MobileNetVisionConfig, inheriting TimmWrapperConfig

* Updating text configs

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Removing altup configs to accept the suggested configs

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Updating altup config

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Addressing review comments and updating text configs

* Adding a config for activation sparsity

* Updating configs to pass through options to super class init and adjust some name prefixes

* Updating laurel and altup with corrected config values

* Normalizing sub_config initializers

---------

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Updating MLP with activation sparsity (#2)

* Updating DecoderBlock for Gemma 3n (#3)

* Initial Gemma3nTextModel (#4)

NOTE: This implementation WILL CHANGE in the coming weeks, however, changes will be strictly additive and this will remain a suitable baseline for downstream implementations to reference.

* Adding KV Cache Sharing

* Adds Einsum layer to Gemma 3n

* Updating EinsumLayer API

* Refactored kv cache sharing in attention

* Adding KVStore for cache sharing

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update modular

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Update src/transformers/cache_utils.py

Co-authored-by: Ryan Mullins <ryanmullins@google.com>

* Undoing erroneous force push

* Reverting RMSNorm to with_scale by default

* Adds LAuReL to Gemma 3n

* Updating KV Cache Sharing implementation

* Updating the q and k norm definitions in the attention module

* Fixing name error for q,k,v RMS norm to use the right 3n module

* Updating MLP with activation sparsity

* Updating DecoderBlock for Gemma 3n

* Updating kv cache sharing implementation with the use of a cache buffer and refactoring some lines of code

* Isolating KV Cache logic to relevant components

* Fixing logic error in Gemma3nAttention.forward

* Refactoring caching contributions and fixing kv_store initialization

* Simplifying Configs

* Remove errant self from super init call

* Bug fix in the Attention module - changing self.head_dim to config.head_dim

* Bug fixes in the LaurelBlock and RMS Norm super init call

* removing redundant code from a merge

* Adding per_layer_inputs to TextModel

* Adding preprocess embeddings with altup

* Adds per-layer-to-single output and a host of TODOs

* Integrating altup predict with the model workflow and other minor bug fixes

* Using nn.Embedding temporarily for text model

* It goes forward

* Minor refactor of attention sparsity and RoPE initialization

* Fixing duplicate rope_scaling param bug when loading from pretrained

---------

Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>
Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>

* Normalizing on altup_num_inputs config option

* Adding audio encoder config

* Adds high-level components for Audio Encoder

* Implement uniform reducer for Audio Encoder

* Adding placeholders for Conformer components in Audio Encoder

* Adding placeholders for SubSampleConvProjection components in Audio Encoder

* Adding SequenceLayer component placeholders

* Implementing Gemma3nAudioEncoder with nn.Sequential

* Implementing Gemma3nAudioSubSampleConvProjection with nn.Sequential

* Implementing Conformer model with SequenceLayers

* Use OrderedDict in nn.Sequential initializers

* Implements sl.Residual in Torch with nn.Sequential and OrderedDict

* Adopting a base SequenceLayer class with default forward() method

* Implementing sl.GatedLinearUnit in Torch

* Implementing sl.Swish in Torch

* Implementing sl.ReLU in Torch

* Implementing sl.Scale in Torch

* Removing sl.Dropout after tree-shaking

* Implementing sl.RMSNorm in Torch with fake shape

* Implementing sl.GroupNorm in Torch

* Implementing sl.Conv2d in Torch

* Implementing sl.Dense in Torch

* Removing sl.Delay layers, which act as pass-throughs

* Connecting shapes to configs in initializers

* Removing sl.Emit

* Implementing sl.ExpandDims in Torch

* Adding sl.GradientClipping to Torch

* Implementing sl.DenseShaped in Torch

* Implementing sl.LDPA in Torch

* Removing unused sl.CombinedQKVProj class

* Fixing erroneous type hint

* Implemnenting sl.DepthwiseConv1D in Torch

* Implementing sl.MaskInvalid in Torch

* Fixes for initialization

* Fixes for saving weights

* Removing einsums per feedback from HF staff

* Removing Sequence Layers idioms from audio encoder

* Fixes for reviewer comments

* Converting sl.Frontend to FeatureExtractor

* Updates for ConditionalGeneration.get_image_features

* Adding a WIP draft of image_processing_gemma3n.py

* Update modular

Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>

* Modular conversion after github suggested change

* Text + image gives good results

* Fixing image size preset

* Draft of audio data in chat template

* Removing image processing. Using SigLIP instead.

* Audio input going end-to-end

* Fixing dtype issues in audio encoder

* x-lib formatting consistency

* Adding example data

* Save preprocessor_config.json from conversion script

* Instrumentaiton for debugging

* Additional instrumentation for preprocessing debugging

* Updates to preprocessor, padding; produces correct end-to-end results on sample

* Tackling configuraiton TODOs

* Start of feature extractor refatcor

* Adds Numpy version of USM extractor, removes Torch version and dependencies

* Fixing AltUp.correct coef permute

* Supporting batches of single audio segment inputs

* Docstrings updates for config

* In-lining audio feature extraction

* Adjustments to conversion script and smoke test script

---------

Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>
Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>
Co-authored-by: pculliton <phillipculliton@gmail.com>

* Gemma 3n renaming

* Removing test data and utilities

* Renaming test files

* Gemma 3n refactor

* Fix tokenizer config in conversion script

* Address reviewer feedback

* FeatureExtractor returns float32 by default

* Adding basic tests for audio, and input name for audio encoder

* Audio integration test, updates to model_id for other integration tests

* Use scales for q and k norms (#26)

* Update audio integration test to use HF dataset

* Reviewer feedback

* Expand embedding table to full vocab size in weights conversion

* Mix-n-match MatFormers for Gemma 3n (#25)

* Remove in-place operations (#30)

* chore: removing inplace ops

* remove [tensor] * n pattern

* chore: reviewer feedback in AudioEncoder and AltUp

* More grad clipping

* Dynamo compatibility

* fix: cache slicing error

* chore: simplify shared kv cache slicing

* chore: vision encoder rename in timm

* fix: image processor do_normalize=False

* fixup: style

* chore: model_doc

* fix: docs for code quality

* chore: repo consistency

* fix: RMSNorm in float as in prior Gemmas

* fix: per_layer_inputs = None

* chore: Gemma3nForCausalLM from Gemma3nForConditionalGeneration checkpoint

* chore: repo consistency

* Add initial unit tests for Gemma3nAudioFeatureExtractor (#27)

* Add initial unit tests for Gemma3nAudioFeatureExtractor

* Add basic unit tests for Gemma3nProcessor (#28)

Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com>

* parameterize tests

---------

Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com>

* chore: code style

* fix: test cases

* style and consistency

* fix config in the test to be coherent with layer cache sharing

* fix hidden states in tests and code

* inits and mappings

* fix modality prefixes

* test order and prefixes

* fix test exception

* fix class order and reduce model size for faster tests

* restore _checkpoint_conversion_mapping to load Caual from Conditional

* fix config mapping!

* fix: reviewer feedback

---------

Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>
Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>
Co-authored-by: raushan <raushan@huggingface.co>
Co-authored-by: Mayank Chaturvedi <imayank@google.com>
Co-authored-by: Douglas Reid <douglas-reid@users.noreply.github.com>
Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
Co-authored-by: pculliton <phillipculliton@gmail.com>
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>

* fix import test

* add model args

* auto_docstring

* replace test path

* consistency

* skip tests for now

* fix docstring for doc builder

* skip unused attr

---------

Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>
Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>
Co-authored-by: raushan <raushan@huggingface.co>
Co-authored-by: Mayank Chaturvedi <imayank@google.com>
Co-authored-by: Douglas Reid <douglas-reid@users.noreply.github.com>
Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
Co-authored-by: pculliton <phillipculliton@gmail.com>
Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com>
Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Co-authored-by: Arthur <arthur.zucker@gmail.com>
This commit is contained in:
Ryan Mullins
2025-06-26 11:55:47 -04:00
committed by GitHub
parent 3e5cc12855
commit c63cfd6a83
22 changed files with 8723 additions and 0 deletions

View File

@@ -959,6 +959,8 @@
title: FLAVA title: FLAVA
- local: model_doc/gemma3 - local: model_doc/gemma3
title: Gemma3 title: Gemma3
- local: model_doc/gemma3n
title: Gemma3n
- local: model_doc/git - local: model_doc/git
title: GIT title: GIT
- local: model_doc/glm4v - local: model_doc/glm4v

View File

@@ -0,0 +1,204 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>
# Gemma3n
## Overview
Gemma3n is a multimodal model with pretrained and instruction-tuned variants, available in E4B and E2B sizes. While
large portions of the language model architecture are shared with prior Gemma releases, there are many new additions in
this model, including [Alternating Updates][altup] (AltUp), [Learned Augmented Residual Layer][laurel] (LAuReL),
[MatFormer][matformer], Per-Layer Embeddings (PLE), activation sparsity, and KV cache sharing. The language model uses
a similar attention pattern to [Gemma 3](./gemma3.md) with alternating 4 local sliding window self-attention layers for
every global self-attention layer with a maximum context length of 32k tokens. Gemma 3n introduces
[MobileNet v5][mobilenetv5] as the vision encoder, using a default resolution of 768x768 pixels, and adds a
[Universal Speech Model][usm] (USM) as the audio encoder.
The instruction-tuned variant was post-trained with knowledge distillation and reinforcement learning.
You can find all the original Gemma 3n checkpoints under the [Gemma 3n][gemma3n-collection] release.
> [!TIP]
> Click on the Gemma 3n models in the right sidebar for more examples of how to apply Gemma to different vision, audio,
> and language tasks.
The example below demonstrates how to generate text based on an image with [`Pipeline`] or the [`AutoModel`] class.
<hfoptions id="usage">
<hfoption id="Pipeline">
```py
import torch
from transformers import pipeline
pipeline = pipeline(
task="image-text-to-text",
model="google/gemma-3n-e4b",
device=0,
torch_dtype=torch.bfloat16
)
pipeline(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
text="<start_of_image> What is shown in this image?"
)
```
</hfoption>
<hfoption id="AutoModel">
```py
import torch
from transformers import AutoProcessor, Gemma3nForConditionalGeneration
model = Gemma3nForConditionalGeneration.from_pretrained(
"google/gemma-3n-e4b-it",
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="sdpa"
)
processor = AutoProcessor.from_pretrained(
"google/gemma-3n-e4b-it",
padding_side="left"
)
messages = [
{
"role": "system",
"content": [
{"type": "text", "text": "You are a helpful assistant."}
]
},
{
"role": "user", "content": [
{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
{"type": "text", "text": "What is shown in this image?"},
]
},
]
inputs = processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
).to("cuda")
output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
print(processor.decode(output[0], skip_special_tokens=True))
```
</hfoption>
<hfoption id="transformers CLI">
```bash
echo -e "Plants create energy through a process known as" | transformers run --task text-generation --model google/gemma-3n-e2b --device 0
```
</hfoption>
</hfoptions>
## Notes
- Use [`Gemma3nForConditionalGeneration`] for image-audio-and-text, image-and-text, image-and-audio, audio-and-text,
image-only and aduio-only inputs.
- Gemma 3n supports multiple images per input, but make sure the images are correctly batched before passing them to
the processor. Each batch should be a list of one or more images.
```py
url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
url_cat = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
messages =[
{
"role": "system",
"content": [
{"type": "text", "text": "You are a helpful assistant."}
]
},
{
"role": "user",
"content": [
{"type": "image", "url": url_cow},
{"type": "image", "url": url_cat},
{"type": "text", "text": "Which image is cuter?"},
]
},
]
```
- Text passed to the processor should have a `<image_soft_token>` token wherever an image should be inserted.
- Gemma 3n accept at most one target audio clip per input, though multiple audio clips can be provided in few-shot
prompts, for example.
- Text passed to the processor should have a `<audio_soft_token>` token wherever an audio clip should be inserted.
- The processor has its own [`~ProcessorMixin.apply_chat_template`] method to convert chat messages to model inputs.
## Gemma3nAudioFeatureExtractor
[[autodoc]] Gemma3nAudioFeatureExtractor
## Gemma3nProcessor
[[autodoc]] Gemma3nProcessor
## Gemma3nTextConfig
[[autodoc]] Gemma3nTextConfig
## Gemma3nVisionConfig
[[autodoc]] Gemma3nVisionConfig
## Gemma3nAudioConfig
[[autodoc]] Gemma3nAudioConfig
## Gemma3nConfig
[[autodoc]] Gemma3nConfig
## Gemma3nTextModel
[[autodoc]] Gemma3nTextModel
- forward
## Gemma3nModel
[[autodoc]] Gemma3nModel
- forward
## Gemma3nForCausalLM
[[autodoc]] Gemma3nForCausalLM
- forward
## Gemma3nForConditionalGeneration
[[autodoc]] Gemma3nForConditionalGeneration
- forward
[altup]: https://proceedings.neurips.cc/paper_files/paper/2023/hash/f2059277ac6ce66e7e5543001afa8bb5-Abstract-Conference.html
[attention-mask-viz]: https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139
[gemma3n-collection]: https://huggingface.co/collections/google/gemma-3n
[laurel]: https://arxiv.org/abs/2411.07501
[matformer]: https://arxiv.org/abs/2310.07707
[usm]: https://arxiv.org/abs/2303.01037

View File

@@ -140,6 +140,10 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("gemma2", "Gemma2Config"), ("gemma2", "Gemma2Config"),
("gemma3", "Gemma3Config"), ("gemma3", "Gemma3Config"),
("gemma3_text", "Gemma3TextConfig"), ("gemma3_text", "Gemma3TextConfig"),
("gemma3n", "Gemma3nConfig"),
("gemma3n_audio", "Gemma3nAudioConfig"),
("gemma3n_text", "Gemma3nTextConfig"),
("gemma3n_vision", "Gemma3nVisionConfig"),
("git", "GitConfig"), ("git", "GitConfig"),
("glm", "GlmConfig"), ("glm", "GlmConfig"),
("glm4", "Glm4Config"), ("glm4", "Glm4Config"),
@@ -518,6 +522,10 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("gemma2", "Gemma2"), ("gemma2", "Gemma2"),
("gemma3", "Gemma3ForConditionalGeneration"), ("gemma3", "Gemma3ForConditionalGeneration"),
("gemma3_text", "Gemma3ForCausalLM"), ("gemma3_text", "Gemma3ForCausalLM"),
("gemma3n", "Gemma3nForConditionalGeneration"),
("gemma3n_audio", "Gemma3nAudioEncoder"),
("gemma3n_text", "Gemma3nForCausalLM"),
("gemma3n_vision", "TimmWrapperModel"),
("git", "GIT"), ("git", "GIT"),
("glm", "GLM"), ("glm", "GLM"),
("glm4", "GLM4"), ("glm4", "GLM4"),
@@ -839,6 +847,9 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
("clip_text_model", "clip"), ("clip_text_model", "clip"),
("aria_text", "aria"), ("aria_text", "aria"),
("gemma3_text", "gemma3"), ("gemma3_text", "gemma3"),
("gemma3n_audio", "gemma3n"),
("gemma3n_text", "gemma3n"),
("gemma3n_vision", "gemma3n"),
("glm4v_text", "glm4v"), ("glm4v_text", "glm4v"),
("idefics3_vision", "idefics3"), ("idefics3_vision", "idefics3"),
("siglip_vision_model", "siglip"), ("siglip_vision_model", "siglip"),

View File

@@ -61,6 +61,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("dpt", "DPTFeatureExtractor"), ("dpt", "DPTFeatureExtractor"),
("encodec", "EncodecFeatureExtractor"), ("encodec", "EncodecFeatureExtractor"),
("flava", "FlavaFeatureExtractor"), ("flava", "FlavaFeatureExtractor"),
("gemma3n", "Gemma3nAudioFeatureExtractor"),
("glpn", "GLPNFeatureExtractor"), ("glpn", "GLPNFeatureExtractor"),
("granite_speech", "GraniteSpeechFeatureExtractor"), ("granite_speech", "GraniteSpeechFeatureExtractor"),
("groupvit", "CLIPFeatureExtractor"), ("groupvit", "CLIPFeatureExtractor"),

View File

@@ -88,6 +88,7 @@ else:
("focalnet", ("BitImageProcessor", "BitImageProcessorFast")), ("focalnet", ("BitImageProcessor", "BitImageProcessorFast")),
("fuyu", ("FuyuImageProcessor",)), ("fuyu", ("FuyuImageProcessor",)),
("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")), ("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
("gemma3n", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")), ("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")),
("glpn", ("GLPNImageProcessor",)), ("glpn", ("GLPNImageProcessor",)),

View File

@@ -132,6 +132,10 @@ MODEL_MAPPING_NAMES = OrderedDict(
("gemma2", "Gemma2Model"), ("gemma2", "Gemma2Model"),
("gemma3", "Gemma3Model"), ("gemma3", "Gemma3Model"),
("gemma3_text", "Gemma3TextModel"), ("gemma3_text", "Gemma3TextModel"),
("gemma3n", "Gemma3nModel"),
("gemma3n_audio", "Gemma3nAudioEncoder"),
("gemma3n_text", "Gemma3nTextModel"),
("gemma3n_vision", "TimmWrapperModel"),
("git", "GitModel"), ("git", "GitModel"),
("glm", "GlmModel"), ("glm", "GlmModel"),
("glm4", "Glm4Model"), ("glm4", "Glm4Model"),
@@ -583,6 +587,8 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("gemma2", "Gemma2ForCausalLM"), ("gemma2", "Gemma2ForCausalLM"),
("gemma3", "Gemma3ForConditionalGeneration"), ("gemma3", "Gemma3ForConditionalGeneration"),
("gemma3_text", "Gemma3ForCausalLM"), ("gemma3_text", "Gemma3ForCausalLM"),
("gemma3n", "Gemma3nForConditionalGeneration"),
("gemma3n_text", "Gemma3nForCausalLM"),
("git", "GitForCausalLM"), ("git", "GitForCausalLM"),
("glm", "GlmForCausalLM"), ("glm", "GlmForCausalLM"),
("glm4", "Glm4ForCausalLM"), ("glm4", "Glm4ForCausalLM"),
@@ -906,6 +912,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
("emu3", "Emu3ForConditionalGeneration"), ("emu3", "Emu3ForConditionalGeneration"),
("fuyu", "FuyuForCausalLM"), ("fuyu", "FuyuForCausalLM"),
("gemma3", "Gemma3ForConditionalGeneration"), ("gemma3", "Gemma3ForConditionalGeneration"),
("gemma3n", "Gemma3nForConditionalGeneration"),
("git", "GitForCausalLM"), ("git", "GitForCausalLM"),
("glm4v", "Glm4vForConditionalGeneration"), ("glm4v", "Glm4vForConditionalGeneration"),
("got_ocr2", "GotOcr2ForConditionalGeneration"), ("got_ocr2", "GotOcr2ForConditionalGeneration"),

View File

@@ -66,6 +66,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("flava", "FlavaProcessor"), ("flava", "FlavaProcessor"),
("fuyu", "FuyuProcessor"), ("fuyu", "FuyuProcessor"),
("gemma3", "Gemma3Processor"), ("gemma3", "Gemma3Processor"),
("gemma3n", "Gemma3nProcessor"),
("git", "GitProcessor"), ("git", "GitProcessor"),
("glm4v", "Glm4vProcessor"), ("glm4v", "Glm4vProcessor"),
("got_ocr2", "GotOcr2Processor"), ("got_ocr2", "GotOcr2Processor"),

View File

@@ -236,6 +236,20 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
"GemmaTokenizerFast" if is_tokenizers_available() else None, "GemmaTokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
(
"gemma3n",
(
"GemmaTokenizer" if is_sentencepiece_available() else None,
"GemmaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"gemma3n_text",
(
"GemmaTokenizer" if is_sentencepiece_available() else None,
"GemmaTokenizerFast" if is_tokenizers_available() else None,
),
),
("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)), ("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),

View File

@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_gemma3n import *
from .feature_extraction_gemma3n import *
from .modeling_gemma3n import *
from .processing_gemma3n import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@@ -0,0 +1,680 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/gemma3n/modular_gemma3n.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_gemma3n.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Any, Optional, Union
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...modeling_rope_utils import rope_config_validation
from ...utils import is_timm_available, logging, requires_backends
if is_timm_available():
from timm.data import ImageNetInfo, infer_imagenet_subset
logger = logging.get_logger(__name__)
class Gemma3nTextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Gemma3nTextModel`]. It is used to instantiate an
Gemma3nTextModel model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.
[google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
Configuration objects that inherit from [`Gemma3nTextConfig`] and can be used to control the model outputs. Read
the documentation from [`Gemma3nTextConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 262400):
Vocabulary size of the Gemma3nText model. Defines the number of different tokens that can be represented by
the `inputs_ids` passed when calling [`Gemma3nTextModel`]
vocab_size_per_layer_input (`int`, *optional*, defaults to 262144):
Vocabulary size of the per-layer text embeddings that augment the standard embeddings.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
hidden_size_per_layer_input (`int`, *optional*, defaults to 256):
Dimension of the hidden representations for per-layer emebeddings.
intermediate_size (`int` or `Sequence[int]`, *optional*, defaults to 16384):
Dimension of the MLP representations. MatFormer configurations may wish to provide a sequence of integers
to account for vairable intermediate_size values across layers. In such cases,
`len(intermediate_size) == num_hidden_layers`.
num_hidden_layers (`int`, *optional*, defaults to 35):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 2):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout this
[paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, will default to `num_attention_heads`.
head_dim (`int`, *optional*, defaults to 256):
The attention head dimension.
hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the decoder. Will default to
`"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
activation function.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 2):
Beginning of stream token id.
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings used in gloabl attention.
NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we
recommend you to update this value accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
rope_local_base_freq (float, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings for local attention.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
sliding_window (`int`, *optional*, defaults to 512):
This is the size of the sliding window used by local attention layers.
layer_types (`Optional`, *optional*):
A sequence of strings defining the attention type for that layer as either "sliding_attention" or
"full_attention". If not provided, `layer_types` will de inferred from `num_hidden_layers` using a pattern
of four "sliding_attention" layers followed one "full_attention". The last layer in the model should always
be a "full_attention" layer.
final_logit_softcapping (`float`, *optional*, defaults to 30.0):
Scaling factor when applying tanh softcapping on the logits.
altup_active_idx (`int`, *optional*, defaults to 0):
The index of the prediction from which AltUp will compute additional predictions or correct
altup_coef_clip (`float`, *optional*, defaults to 120.0):
The maximum amplitude of an AltUp prediction or correction coeficient weight.
altup_correct_scale (`bool`, *optional*, defaults to `True`):
If True, apply the `AltUp.correct_output_scale` to the corrected prediction at `altup_active_idx`.
altup_num_inputs (`int`, *optional*, defaults to 4):
The number of predictions that AltUp should be make given the input sequence.
num_kv_shared_layers (`int`, *optional*, defaults to 15):
The number of layer that share KV cache values. During the forward pass, the last `num_kv_shared_layers`
layers in the model "share" the KV values in that each local and global layer in this range uses the KV
cache values computed for the last local or global layer, respectively, before entering this range. The
value should be `num_kv_shared_layers` should be a scalar of `sliding_window_pattern`.
laurel_rank (int, *optional*, defaults to 64):
The intermediate size for the linear projections in the Learned Augmented Residual Layer.
activation_sparsity_pattern (Sequence[float], *optional*, defaults to `(0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)`):
The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must
explicitly provide a sparsity value for each layer in the model.
```python
>>> from transformers import Gemma3nTextModel, Gemma3nTextConfig
>>> # Initializing a Gemma3nText gemma3n_text-E4B style configuration
>>> configuration = Gemma3nTextConfig()
>>> # Initializing a model from the gemma3n_text-E4B style configuration
>>> model = Gemma3nTextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "gemma3n_text"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size: int = 262_400,
vocab_size_per_layer_input: int = 262_144,
hidden_size: int = 2048,
hidden_size_per_layer_input: int = 256,
intermediate_size: Union[int, Sequence[int]] = 16_384,
num_hidden_layers: int = 35,
num_attention_heads: int = 8,
num_key_value_heads: int = 2,
head_dim: int = 256,
hidden_activation: str = "gelu_pytorch_tanh",
max_position_embeddings: int = 32_768,
initializer_range: float = 0.02,
rms_norm_eps: float = 1e-6,
use_cache: bool = True,
pad_token_id: int = 0,
eos_token_id: int = 1,
bos_token_id: int = 2,
rope_theta: float = 1_000_000.0,
rope_scaling: Optional[dict[str, Any]] = None,
rope_local_base_freq: float = 10_000.0,
attention_bias: bool = False,
attention_dropout: float = 0.0,
sliding_window: int = 512,
layer_types: Optional[Sequence[str]] = None,
final_logit_softcapping: float = 30.0,
altup_active_idx: int = 0,
altup_coef_clip: float = 120.0,
altup_correct_scale: bool = True,
altup_num_inputs: int = 4,
num_kv_shared_layers: int = 15,
laurel_rank: int = 64,
activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = (0.95,) * 10 + (0.0,) * 25,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs,
)
if isinstance(intermediate_size, Sequence) and (intsize_len := len(intermediate_size)) != num_hidden_layers:
raise ValueError(
"intermediate_size must have an explicit intermediate size for every layer or one for all layers. "
f"Expected {num_hidden_layers} values but got {intsize_len}."
)
elif not isinstance(intermediate_size, Sequence):
intermediate_size = [intermediate_size] * num_hidden_layers
self.vocab_size = vocab_size
self.vocab_size_per_layer_input = vocab_size_per_layer_input
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.num_key_value_heads = num_key_value_heads
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.hidden_activation = hidden_activation
self.sliding_window = sliding_window
self.final_logit_softcapping = final_logit_softcapping
self.layer_types = layer_types
self.rope_local_base_freq = rope_local_base_freq
self.rope_scaling = rope_scaling
rope_config_validation(self)
if layer_types is None:
self.layer_types = [
"full_attention" if i % 5 == 0 else "sliding_attention" for i in range(self.num_hidden_layers)
]
else:
self.layer_types = layer_types
layer_type_validation(self.layer_types)
self.hidden_size_per_layer_input = hidden_size_per_layer_input
self.num_kv_shared_layers = num_kv_shared_layers
self.altup_active_idx = altup_active_idx
self.altup_coef_clip = altup_coef_clip
self.altup_correct_scale = altup_correct_scale
self.altup_num_inputs = altup_num_inputs
self.laurel_rank = laurel_rank
if activation_sparsity_pattern is None:
activation_sparsity_pattern = [0.0] * num_hidden_layers
if (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers:
raise ValueError(
"activation_sparsity_pattern must have an explicit activation sparsity value for every layer."
f"Expected {num_hidden_layers} values but got {len_asp}."
)
self.activation_sparsity_pattern = activation_sparsity_pattern
class Gemma3nAudioConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`], based on Gogole's
[Universal Speech Model](). It is used to instantiate an Gemma3nAudioEncoder model according to the specified
arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar
configuration to that of the Gemma 3n E4B, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
Configuration objects that inherit from [`Gemma3nAudioConfig`] and can be used to control the model outputs. Read
the documentation from [`Gemma3nAudioConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 128):
Vocabulary size of the additional hard-token embeddings for audio model. These augment the embeddings
included in the `Gemma3nTextModel` to provide, e.g., the end of audio and audio soft token placeholder
tokens when converting `input_ids` to embeddings in the `Gemma3nForConditionalGeneration` model.
vocab_offset (`int`, *optional*, defaults to 262272):
Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
input_feat_size (`int`, *optional*, defaults to 128):
The number of channels in each mel-spectrogram frame.
hidden_size (`int`, *optional*, defaults to 1536):
Dimension of the hidden representations.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
gradient_clipping (`float`, *optional*, defaults to 10000000000.0):
Clipping value used to stablize extremely large gradient values.
conf_attention_chunk_size (`int`, *optional*, defaults to 12):
The sub-sequence size for local attention processing inside the Conformer ("conf") section of the
Universal Speech Model.
conf_attention_context_left (`int`, *optional*, defaults to 13):
The left context size of the local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_attention_context_right (`int`, *optional*, defaults to 0):
The right context size of the local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_attention_logit_cap (`float`, *optional*, defaults to 50.0):
Logit cap applied during local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_num_attention_heads (`int`, *optional*, defaults to 8):
The number of attention heads in local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_num_hidden_layers (`int`, *optional*, defaults to 12):
The number of layers that use local attention inside the Conformer ("conf") section of the
Universal Speech Model.
conf_conv_kernel_size (`int`, *optional*, defaults to 5):
Convolution kernel size for the conformer block inside the Conformer ("conf") section of the
Universal Speech Model.
conf_reduction_factor (`int`, *optional*, defaults to 4):
Reduction factor used in the conformer block inside the Conformer ("conf") section of the
Universal Speech Model.
conf_residual_weight (`float`, *optional*, defaults to 0.5):
Residual connection weight inside the Conformer ("conf") section of the
Universal Speech Model.
sscp_conv_channel_size (`tuple(int, int)`, *optional*, defaults to `(128, 32)`):
The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection
("sscp") section of the Universal Speech Model.
sscp_conv_group_norm_eps (`float`, *optional*, defaults to 0.001):
Epsilon used in group normalization in the subsample convolution projection in the Sub-sample Convolution
Projection ("sscp") section of the Universal Speech Model.
sscp_conv_kernel_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((3, 3), (3, 3))`):
Kernel sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
Convolution Projection ("sscp") section of the Universal Speech Model. The kernel sizes are specified as a
tuple of height and width for each layer, where the height corresponds to the time dimension and the width
corresponds to the frequency dimension.
sscp_conv_stride_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((2, 2), (2, 2))`):
Stride sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
Convolution Projection ("sscp") section of the Universal Speech Model. The stride sizes are specified as a
tuple of height and width for each layer, where the height corresponds to the time dimension and the width
corresponds to the frequency dimension.
Example:
```python
>>> from transformers import Gemma3nAudioConfig, Gemma3nAudioEncoder
>>> # Initializing a Gemma3nAudioEncoder gemma3n_audio-E4B-style configuration
>>> configuration = Gemma3nAudioConfig()
>>> # Initializing a model from the gemma3n_audio-E4B style configuration
>>> model = Gemma3nAudioEncoder(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "gemma3n_audio"
def __init__(
self,
vocab_size: int = 128,
vocab_offset: int = 262_144 + 128, # text vocab size + vision vocab size
input_feat_size: int = 128,
hidden_size: int = 1536,
rms_norm_eps: float = 1e-6,
gradient_clipping: float = 10_000_000_000.0,
conf_attention_chunk_size: int = 12,
conf_attention_context_left: int = 13,
conf_attention_context_right: int = 0,
conf_attention_logit_cap: float = 50.0,
conf_num_attention_heads: int = 8,
conf_num_hidden_layers: int = 12,
conf_conv_kernel_size: int = 5,
conf_reduction_factor: int = 4,
conf_residual_weight: float = 0.5,
sscp_conv_channel_size: tuple[int, int] = (128, 32),
sscp_conv_group_norm_eps: float = 1e-3,
sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = (
(3, 3),
(3, 3),
),
sscp_conv_stride_size: tuple[tuple[int, int], tuple[int, int]] = (
(2, 2),
(2, 2),
),
**kwargs,
):
super().__init__(**kwargs)
self.input_feat_size = input_feat_size
self.hidden_size = hidden_size
self.rms_norm_eps = rms_norm_eps
self.vocab_size = vocab_size
self.vocab_offset = vocab_offset
self.gradient_clipping = gradient_clipping
self.conf_attention_chunk_size = conf_attention_chunk_size
self.conf_attention_context_left = conf_attention_context_left
self.conf_attention_context_right = conf_attention_context_right
self.conf_attention_logit_cap = conf_attention_logit_cap
self.conf_num_attention_heads = conf_num_attention_heads
self.conf_num_hidden_layers = conf_num_hidden_layers
self.conf_conv_kernel_size = conf_conv_kernel_size
self.conf_reduction_factor = conf_reduction_factor
self.conf_residual_weight = conf_residual_weight
self.sscp_conv_channel_size = sscp_conv_channel_size
self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps
self.sscp_conv_kernel_size = sscp_conv_kernel_size
self.sscp_conv_stride_size = sscp_conv_stride_size
class Gemma3nVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. It is used to
instantiate an timm model model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B
vision tower, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
Configuration objects inherit from [`Gemma3nVisionConfig`] and can be used to control the model outputs. Read the
documentation from [`Gemma3nVisionConfig`] for more information.
Config loads imagenet label descriptions and stores them in `id2label` attribute, `label2id` attribute for default
imagenet models is set to `None` due to occlusions in the label descriptions.
Args:
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
do_pooling (`bool`, *optional*, defaults to `False`):
Whether to do pooling for the last_hidden_state in `TimmWrapper` or not.
architecture (`str`, *optional*, defaults to `"mobilenetv5_300m_enc"`):
Determines vision architecture for TimmWrapper.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
vocab_size (`int`, *optional*, defaults to 128):
Vocabulary size of the additional hard-token embeddings for vision model.
vocab_offset (`int`, *optional*, defaults to 262144):
Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
Example:
```python
>>> from transformers import Gemma3nVisionConfig, TimmWrapper
>>> # Initializing a TimmWrapper gemma3n_vision-E4B-style configuration
>>> configuration = Gemma3nVisionConfig()
>>> # Initializing a gemma3n_vision-E4B-style TimmWrapper from the configuration
>>> model = TimmWrapper(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "gemma3n_vision"
def __init__(
self,
initializer_range: float = 0.02,
do_pooling: bool = False,
architecture: str = "mobilenetv5_300m_enc",
hidden_size: int = 2048,
vocab_size: int = 128,
vocab_offset: int = 262_144,
rms_norm_eps: float = 1e-06,
model_args: Optional[dict] = None,
**kwargs,
):
super().__init__(**kwargs)
self.initializer_range = initializer_range
self.do_pooling = do_pooling
self.model_args = model_args # named "model_args" for BC with timm
self.architecture = architecture
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.vocab_offset = vocab_offset
self.rms_norm_eps = rms_norm_eps
@classmethod
def from_dict(cls, config_dict: dict[str, Any], **kwargs):
label_names = config_dict.get("label_names", None)
is_custom_model = "num_labels" in kwargs or "id2label" in kwargs
# if no labels added to config, use imagenet labeller in timm
if label_names is None and not is_custom_model:
requires_backends(cls, ["timm"])
imagenet_subset = infer_imagenet_subset(config_dict)
if imagenet_subset:
dataset_info = ImageNetInfo(imagenet_subset)
synsets = dataset_info.label_names()
label_descriptions = dataset_info.label_descriptions(as_dict=True)
label_names = [label_descriptions[synset] for synset in synsets]
if label_names is not None and not is_custom_model:
kwargs["id2label"] = dict(enumerate(label_names))
# if all label names are unique, create label2id mapping as well
if len(set(label_names)) == len(label_names):
kwargs["label2id"] = {name: i for i, name in enumerate(label_names)}
else:
kwargs["label2id"] = None
# timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
# We are removing these attributes in order to have the native `transformers` num_labels attribute in config
# and to avoid duplicate attributes
num_labels_in_kwargs = kwargs.pop("num_labels", None)
num_labels_in_dict = config_dict.pop("num_classes", None)
# passed num_labels has priority over num_classes in config_dict
kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict
# pop num_classes from "pretrained_cfg",
# it is not necessary to have it, only root one is used in timm
if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]:
config_dict["pretrained_cfg"].pop("num_classes", None)
return super().from_dict(config_dict, **kwargs)
def to_dict(self) -> dict[str, Any]:
output = super().to_dict()
output["num_classes"] = self.num_labels
output["label_names"] = list(self.id2label.values())
output.pop("id2label", None)
output.pop("label2id", None)
return output
class Gemma3nConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Gemma3nForConditionalGeneration`]. It is used to
instantiate a Gemma3nForConditionalGeneration according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
Gemma3n-E4B.
e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`Union[Gemma3nTextConfig, dict]`, *optional*):
The config object of the text backbone.
vision_config (`Union[AutoConfig, dict]`, *optional*):
Custom vision config or dict.
audio_config (`Union[AutoConfig, dict]`, *optional*):
Custom audio config or dict.
audio_soft_tokens_per_image (`int`, *optional*, defaults to 188):
The number of soft tokens per audio clip.
vision_soft_tokens_per_image (`int`, *optional*, defaults to 256):
The number of soft tokens per image.
boi_token_id (`int`, *optional*, defaults to 255999):
The begin-of-image token index to wrap the image prompt.
eoi_token_id (`int`, *optional*, defaults to 262144):
The end-of-image token index to wrap the image prompt.
image_token_id (`int`, *optional*, defaults to 262145):
The image token index to encode the image prompt.
boa_token_id (`int`, *optional*, defaults to 256000):
The begin-of-audio token index to wrap the audio prompt.
eoa_token_id (`int`, *optional*, defaults to 262272):
The end-of-audio token index to wrap the audio prompt.
audio_token_id (`int`, *optional*, defaults to 262273):
The audio token index to encode the audio prompt.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:
```python
>>> from transformers import Gemma3nForConditionalGeneration, Gemma3nConfig, Gemma3nTextConfig
>>> # Initializing a MobileNet vision config, which is loaded from TIMM
>>> vision_config = Gemma3nVisionConfig()
>>> # Initializing a Gemma3n Audio config
>>> audio_config = Gemma3nAudioConfig()
>>> # Initializing a Gemma3n Text config
>>> text_config = Gemma3nTextConfig()
>>> # Initializing a Gemma3n gemma-3-4b style configuration
>>> configuration = Gemma3nConfig(text_config, vision_config, audio_config)
>>> # Initializing a model from the gemma-3-4b style configuration
>>> model = Gemma3nTextConfig(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "gemma3n"
sub_configs = {
"text_config": Gemma3nTextConfig,
"vision_config": Gemma3nVisionConfig,
"audio_config": Gemma3nAudioConfig,
}
def __init__(
self,
text_config: Optional[Union[Gemma3nTextConfig, dict[str, Any]]] = None,
vision_config: Optional[Union[Gemma3nVisionConfig, dict[str, Any]]] = None,
audio_config: Optional[Union[Gemma3nAudioConfig, dict[str, Any]]] = None,
audio_soft_tokens_per_image: int = 188,
vision_soft_tokens_per_image: int = 256,
boi_token_id: int = 255_999,
eoi_token_id: int = 262_144,
image_token_id: int = 262_145,
boa_token_id: int = 256_000,
eoa_token_id: int = 262_272,
audio_token_id: int = 262_273,
initializer_range: float = 0.02,
**kwargs,
):
super().__init__(**kwargs)
if isinstance(text_config, dict):
text_config = Gemma3nTextConfig(**text_config)
elif text_config is None:
text_config = Gemma3nTextConfig()
logger.info("text_config is None. Using default Gemma3nTextConfig.")
if isinstance(vision_config, dict):
vision_config = Gemma3nVisionConfig(**vision_config)
elif vision_config is None:
vision_config = Gemma3nVisionConfig()
logger.info("vision_config is None. Using default Gemma3nVisionConfig.")
if isinstance(audio_config, dict):
audio_config = Gemma3nAudioConfig(**audio_config)
elif audio_config is None:
audio_config = Gemma3nAudioConfig()
logger.info("audio_config is None. Using default Gemma3nAudioConfig.")
self.text_config = text_config
self.vision_config = vision_config
self.audio_config = audio_config
self.audio_soft_tokens_per_image = audio_soft_tokens_per_image
self.vision_soft_tokens_per_image = vision_soft_tokens_per_image
self.boi_token_id = boi_token_id
self.eoi_token_id = eoi_token_id
self.image_token_id = image_token_id
self.boa_token_id = boa_token_id
self.eoa_token_id = eoa_token_id
self.audio_token_id = audio_token_id
self.initializer_range = initializer_range
__all__ = ["Gemma3nAudioConfig", "Gemma3nConfig", "Gemma3nTextConfig", "Gemma3nVisionConfig"]

View File

@@ -0,0 +1,807 @@
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Utility to convert Gemma models from Orbax to HF Transformers checkpoint.
python src/transformers/models/gemma3n/convert_gemma3n_weights.py \
--variant='gemma3n_e4b' \
--tokenizer_path="$HOME/nano3/checkpoints/tokenizer/gemma-3n-tokenizer.model" \
--checkpoint_path="$HOME/nano3/checkpoints/g251_orbax/" \
--output_path="$HOME/nano3/checkpoints/g251_vision_encoder/"
"""
import json
import os
import re
from collections.abc import Iterable, Mapping
from typing import Any
import accelerate
import numpy as np
import torch
import tree
from absl import app, flags, logging
from orbax import checkpoint as obc
from transformers import (
Gemma3nAudioConfig,
Gemma3nAudioFeatureExtractor,
Gemma3nConfig,
Gemma3nForConditionalGeneration,
Gemma3nProcessor,
Gemma3nTextConfig,
Gemma3nVisionConfig,
GemmaTokenizerFast,
GenerationConfig,
SiglipImageProcessorFast,
)
from transformers.image_utils import PILImageResampling
# ==== Internal Constants and Classes ====
_CHAT_TEMPLATE = """{{ bos_token }}
{%- if messages[0]['role'] == 'system' -%}
{%- if messages[0]['content'] is string -%}
{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
{%- else -%}
{%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}
{%- endif -%}
{%- set loop_messages = messages[1:] -%}
{%- else -%}
{%- set first_user_prefix = "" -%}
{%- set loop_messages = messages -%}
{%- endif -%}
{%- for message in loop_messages -%}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif -%}
{%- if (message['role'] == 'assistant') -%}
{%- set role = "model" -%}
{%- else -%}
{%- set role = message['role'] -%}
{%- endif -%}
{{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else "") }}
{%- if message['content'] is string -%}
{{ message['content'] | trim }}
{%- elif message['content'] is iterable -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'audio' -%}
{{ '<audio_soft_token>' }}
{%- elif item['type'] == 'image' -%}
{{ '<image_soft_token>' }}
{%- elif item['type'] == 'text' -%}
{{ item['text'] | trim }}
{%- endif -%}
{%- endfor -%}
{%- else -%}
{{ raise_exception("Invalid content type") }}
{%- endif -%}
{{ '<end_of_turn>\n' }}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{'<start_of_turn>model\n'}}
{%- endif -%}
"""
_DTYPES = {"float32", "bfloat16", "float16"}
_SLIDING_WINDOW_PATTERN = 5
_AUDIO_ENCODER_PARAMETER = "AudioEncoder/encoder"
_AUDIO_ENCODER_CONFORMER = f"{_AUDIO_ENCODER_PARAMETER}/conformer/stacked_layers"
_AUDIO_ENCODER_SSCP = f"{_AUDIO_ENCODER_PARAMETER}/feature"
_TRANSFORMER_PARAMETER = "transformer"
_TRANSFORMER_ALTUP_PROJ = f"{_TRANSFORMER_PARAMETER}/altup_projection_"
_TRANSFORMER_ALTUP_UNEMB = f"{_TRANSFORMER_PARAMETER}/altup_unembed_projection_"
_TRANSFORMER_DECODER_BLOCK = f"{_TRANSFORMER_PARAMETER}/stacked_layers/attention_type_"
_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK)
_TRANSFORMER_EMBEDDER = f"{_TRANSFORMER_PARAMETER}/embedder"
_TRANSFORMER_FINAL_NORM = "transformer/final_norm"
_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/"
_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX)
# _MOBILE_NET_CONFIG = Gemma3nVisionConfig.from_pretrained("")
_MOBILE_NET_PREFIX = "mobilenet"
_MOBILE_NET_TIMM_SUMMED_BLOCK_SIZES = [3, 8, 45, 84]
_MOBILE_NET_CONV = "block_group_conv2d_"
_MOBILE_NET_FIB = "block_group_fused_ib_"
_MOBILE_NET_MQA = "block_group_mmqa_"
_MOBILE_NET_MSFA = "block_adapter_"
_MOBILE_NET_UIB = "block_group_uib_"
_MOBILE_NET_UIB_HAS_DW_START = {
(1, 0),
(1, 1),
(1, 2),
(1, 3),
(1, 4),
(2, 0),
(2, 1),
(2, 2),
(2, 3),
(2, 4),
(2, 5),
(2, 6),
(2, 7),
(3, 0),
}
_MOBILE_NET_UIB_HAS_DW_MID = {
(1, 0),
(2, 0),
(3, 0),
}
_VARIANT_GEMMA_3_2B = "gemma3n_e2b"
_VARIANT_GEMMA_3_4B = "gemma3n_e4b"
_VARIANTS: Mapping[str, Gemma3nConfig] = {
_VARIANT_GEMMA_3_2B: Gemma3nConfig(
text_config=Gemma3nTextConfig(
intermediate_size=2048 * 4,
num_hidden_layers=30,
activation_sparsity_pattern=(0.95,) * 10 + (0.0,) * 20,
num_kv_shared_layers=10,
),
vision_config=Gemma3nVisionConfig(),
audio_config=Gemma3nAudioConfig(),
),
_VARIANT_GEMMA_3_4B: Gemma3nConfig(
text_config=Gemma3nTextConfig(),
vision_config=Gemma3nVisionConfig(),
audio_config=Gemma3nAudioConfig(),
),
}
# ==== Flags ====
_AUDIO_DTYPE = flags.DEFINE_enum(
name="audio_dtype",
default="bfloat16",
help="The floating point precision (aka dtype) of the model.",
enum_values=_DTYPES,
)
_CHECKPOINT_PATH = flags.DEFINE_string(
name="checkpoint_path",
default=None,
help="Path to the Orbax checkpoint.",
required=True,
)
_INCLUDE_CHAT_TEMPLATE = flags.DEFINE_bool(
name="include_chat_template", default=False, help="If true, will save the default chat template with the tokenizer"
)
_OUTPUT_PATH = flags.DEFINE_string(
name="output_path",
default=None,
help="Path to store the HF checkpoint.",
required=True,
)
_TRANSFORMER_DTYPE = flags.DEFINE_enum(
name="text_dtype",
default="bfloat16",
help="The floating point precision (aka dtype) of the model.",
enum_values=_DTYPES,
)
_TOKENIZER_PATH = flags.DEFINE_string(
name="tokenizer_path",
default=None,
help="Path to the SentencePiece model file.",
required=True,
)
_VARIANT = flags.DEFINE_enum(
name="variant",
default=_VARIANT_GEMMA_3_4B,
help="The model variant to convert.",
enum_values=set(_VARIANTS.keys()),
)
_VERBOSE = flags.DEFINE_bool(
name="verbose",
default=False,
help="If true, log the path, shape, and dtype of every converted layer.",
)
_VISION_DTYPE = flags.DEFINE_enum(
name="vision_dtype",
default="bfloat16",
help="The floating point precision (aka dtype) of the model.",
enum_values=_DTYPES,
)
def convert_audio_encoder_weights(
config: Gemma3nAudioConfig,
path: str,
param: str,
weights: np.ndarray,
) -> Iterable[tuple[str, np.ndarray]]:
converted_paths: list[str] = []
converted_weights: list[Any] = []
if path.startswith(_AUDIO_ENCODER_CONFORMER):
assert weights.shape[0] == config.conf_num_hidden_layers
for i, matrix in enumerate(weights):
if "fflayer_end" in path:
base = f"conformer.{i}.ffw_layer_end"
if path.endswith("ffn_layer1"):
converted_paths.append(f"{base}.ffw_layer_1.weight")
converted_weights.append(matrix.transpose())
elif path.endswith("ffn_layer2"):
converted_paths.append(f"{base}.ffw_layer_2.weight")
converted_weights.append(matrix.transpose())
elif path.endswith("post_layer_norm"):
converted_paths.append(f"{base}.post_layer_norm.weight")
converted_weights.append(matrix)
elif path.endswith("pre_layer_norm"):
converted_paths.append(f"{base}.pre_layer_norm.weight")
converted_weights.append(matrix)
elif "fflayer_start" in path:
base = f"conformer.{i}.ffw_layer_start"
if path.endswith("ffn_layer1"):
converted_paths.append(f"{base}.ffw_layer_1.weight")
converted_weights.append(matrix.transpose())
elif path.endswith("ffn_layer2"):
converted_paths.append(f"{base}.ffw_layer_2.weight")
converted_weights.append(matrix.transpose())
elif path.endswith("post_layer_norm"):
converted_paths.append(f"{base}.post_layer_norm.weight")
converted_weights.append(matrix)
elif path.endswith("pre_layer_norm"):
converted_paths.append(f"{base}.pre_layer_norm.weight")
converted_weights.append(matrix)
elif path.endswith("final_ln"):
converted_paths.append(f"conformer.{i}.norm.weight")
converted_weights.append(matrix)
elif "lconv" in path:
base = f"conformer.{i}.lconv1d"
if path.endswith("conv_norm"):
converted_paths.append(f"{base}.conv_norm.weight")
converted_weights.append(matrix)
elif path.endswith("depthwise_conv1d"):
converted_paths.append(f"{base}.depthwise_conv1d.weight")
converted_weights.append(matrix.transpose())
elif path.endswith("linear_end"):
converted_paths.append(f"{base}.linear_end.weight")
converted_weights.append(matrix.transpose())
elif path.endswith("linear_start"):
converted_paths.append(f"{base}.linear_start.weight")
converted_weights.append(matrix.transpose())
elif path.endswith("ln"):
converted_paths.append(f"{base}.pre_layer_norm.weight")
converted_weights.append(matrix)
elif "trans_atten" in path:
base = f"conformer.{i}.attention"
if param == "per_dim_scale":
converted_paths.append(f"{base}.attn.per_dim_scale")
converted_weights.append(matrix)
if path.endswith("query_key_value_projection"):
converted_paths.extend(
[f"{base}.attn.q_proj.weight", f"{base}.attn.k_proj.weight", f"{base}.attn.v_proj.weight"]
)
converted_weights.extend(
[
m.reshape(config.hidden_size, config.hidden_size).transpose()
for m in matrix.transpose(1, 0, 2, 3)
]
)
elif path.endswith("pos_proj"):
converted_paths.append(f"{base}.attn.relative_position_embedding.pos_proj.weight")
converted_weights.append(matrix.reshape(config.hidden_size, config.hidden_size).transpose())
elif path.endswith("post"):
converted_paths.append(f"{base}.post.weight")
converted_weights.append(matrix.transpose(2, 0, 1).reshape(config.hidden_size, config.hidden_size))
elif path.endswith("post_norm"):
converted_paths.append(f"{base}.post_norm.weight")
converted_weights.append(matrix)
elif path.endswith("pre_norm"):
converted_paths.append(f"{base}.pre_attn_norm.weight")
converted_weights.append(matrix)
elif path.startswith(_AUDIO_ENCODER_SSCP):
if path.endswith("input_proj"):
converted_paths.append("subsample_conv_projection.input_proj_linear.weight")
converted_weights.append(
weights.transpose(2, 0, 1).reshape(config.hidden_size, config.sscp_conv_channel_size[1] ** 2)
)
elif "norm_" in path:
index = int(path[-1])
converted_paths.append(f"subsample_conv_projection.conv_{index}.norm.weight")
converted_weights.append(weights)
elif "subsampling_" in path:
index = int(path[-1])
converted_paths.append(f"subsample_conv_projection.conv_{index}.conv.weight")
converted_weights.append(weights.transpose(3, 2, 0, 1))
if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
raise ValueError(
"The `converted_paths` and `converted_weights` should be the same "
f"length. Got {cpl} and {cwl}, respectively, for {path}."
)
return zip(converted_paths, converted_weights)
def convert_transformer_weights(
config: Gemma3nTextConfig,
path: str,
param: str,
weights: np.ndarray,
) -> Iterable[tuple[str, np.ndarray]]:
if path.startswith(_TRANSFORMER_POST_TRAINING_PREFIX):
path = path[_TRANSFORMER_POST_TRAINING_PREFIX_LEN:]
converted_paths: list[str] = []
converted_weights: list[Any] = []
if path.startswith(_TRANSFORMER_ALTUP_PROJ):
index = int(path[-1])
converted_paths.append(f"altup_projections.{index}.weight")
converted_weights.append(weights.transpose())
elif path.startswith(_TRANSFORMER_ALTUP_UNEMB):
index = int(path[-1])
converted_paths.append(f"altup_unembed_projections.{index}.weight")
converted_weights.append(weights.transpose())
elif path.startswith(_TRANSFORMER_DECODER_BLOCK):
attention_type_index = int(path[_TRANSFORMER_DECODER_BLOCK_LEN])
assert weights.shape[0] == config.num_hidden_layers / _SLIDING_WINDOW_PATTERN
for i, matrix in enumerate(weights):
layer_idx = _SLIDING_WINDOW_PATTERN * i + attention_type_index
base_path = f"layers.{layer_idx}"
if "altup" in path:
altup_path = f"{base_path}.altup"
if param == "correct_output_scale":
converted_paths.append(f"{altup_path}.correct_output_scale")
converted_weights.append(matrix)
elif param == "correction_coefs":
converted_paths.append(f"{altup_path}.correction_coefs.weight")
converted_weights.append(matrix.transpose())
elif param == "prediction_coefs":
converted_paths.append(f"{altup_path}.prediction_coefs.weight")
converted_weights.append(
np.clip(
matrix.reshape(config.altup_num_inputs, config.altup_num_inputs**2).transpose(),
-config.altup_coef_clip,
config.altup_coef_clip,
)
)
if path.endswith("modality_router"):
converted_paths.append(f"{altup_path}.modality_router.weight")
converted_weights.append(matrix.transpose())
elif path.endswith("router_norm_layer"):
converted_paths.append(f"{altup_path}.router_norm.weight")
converted_weights.append(matrix)
elif path.endswith("attn/attn_vec_einsum"):
converted_paths.append(f"{base_path}.self_attn.o_proj.weight")
converted_weights.append(
matrix.transpose(2, 0, 1).reshape(config.hidden_size, config.num_attention_heads * config.head_dim)
)
elif path.endswith("attn/kv_einsum"):
converted_paths.extend(
[
f"{base_path}.self_attn.k_proj.weight",
f"{base_path}.self_attn.v_proj.weight",
]
)
k_proj_weights, v_proj_weights = matrix.transpose(0, 2, 1, 3)
kv_proj_shape = (config.hidden_size, config.num_key_value_heads * config.head_dim)
converted_weights.extend(
[
k_proj_weights.reshape(kv_proj_shape).transpose(),
v_proj_weights.reshape(kv_proj_shape).transpose(),
]
)
elif path.endswith("attn/q_einsum"):
converted_paths.append(f"{base_path}.self_attn.q_proj.weight")
converted_weights.append(
matrix.transpose(1, 0, 2)
.reshape(config.hidden_size, config.num_attention_heads * config.head_dim)
.transpose()
)
elif path.endswith("attn/query_norm"):
converted_paths.append(f"{base_path}.self_attn.q_norm.weight")
converted_weights.append(matrix)
elif path.endswith("attn/key_norm"):
converted_paths.append(f"{base_path}.self_attn.k_norm.weight")
converted_weights.append(matrix)
elif path.endswith("laurel_block/linear_left"):
converted_paths.append(f"{base_path}.laurel.linear_left.weight")
converted_weights.append(matrix.transpose())
elif path.endswith("laurel_block/linear_right"):
converted_paths.append(f"{base_path}.laurel.linear_right.weight")
converted_weights.append(matrix.transpose())
elif path.endswith("mlp/gating_einsum"):
converted_paths.extend([f"{base_path}.mlp.gate_proj.weight", f"{base_path}.mlp.up_proj.weight"])
gate_proj_weight, up_proj_weight = matrix
converted_weights.extend([gate_proj_weight, up_proj_weight])
elif path.endswith("mlp/linear"):
converted_paths.append(f"{base_path}.mlp.down_proj.weight")
converted_weights.append(matrix.transpose())
elif path.endswith("per_layer_input_gate"):
converted_paths.append(f"{base_path}.per_layer_input_gate.weight")
converted_weights.append(matrix.transpose())
elif path.endswith("per_layer_projection"):
converted_paths.append(f"{base_path}.per_layer_projection.weight")
converted_weights.append(matrix.transpose())
elif path.endswith("post_attention_norm"):
converted_paths.append(f"{base_path}.post_attention_layernorm.weight")
converted_weights.append(matrix)
elif path.endswith("post_ffw_norm"):
converted_paths.append(f"{base_path}.post_feedforward_layernorm.weight")
converted_weights.append(matrix)
elif path.endswith("post_laurel_norm"):
converted_paths.append(f"{base_path}.laurel.post_laurel_norm.weight")
converted_weights.append(matrix)
elif path.endswith("post_per_layer_input_norm"):
converted_paths.append(f"{base_path}.post_per_layer_input_norm.weight")
converted_weights.append(matrix)
elif path.endswith("pre_attention_norm"):
converted_paths.append(f"{base_path}.input_layernorm.weight")
converted_weights.append(matrix)
elif path.endswith("pre_ffw_norm"):
converted_paths.append(f"{base_path}.pre_feedforward_layernorm.weight")
converted_weights.append(matrix)
elif path == _TRANSFORMER_EMBEDDER:
if param == "input_embedding":
converted_paths.append("embed_tokens.weight")
# Gemma 3n model doesn't have soft tokens or "end of" tokens for images and audio in its input and output
# embeddings, so we resize to avoid bugs observed with Mllama
pre_expansion_embeddings = weights
pad_token_slice = slice(config.pad_token_id, config.pad_token_id + 1)
new_embeddings = np.repeat(pre_expansion_embeddings[pad_token_slice], 256, axis=0)
weights = np.vstack([pre_expansion_embeddings, new_embeddings])
converted_weights.append(weights)
elif param == "per_layer_embeddings":
converted_paths.append("embed_tokens_per_layer.weight")
converted_weights.append(
weights.reshape(
config.vocab_size_per_layer_input, config.num_hidden_layers * config.hidden_size_per_layer_input
)
)
elif path.startswith(_TRANSFORMER_EMBEDDER):
# TODO: ryanmullins - support multimodal norms and projections
if path.endswith("per_layer_model_projection"):
converted_paths.append("per_layer_model_projection.weight")
converted_weights.append(
weights.reshape(
config.hidden_size, config.num_hidden_layers * config.hidden_size_per_layer_input
).transpose()
)
elif path.endswith("per_layer_projection_norm"):
converted_paths.append("per_layer_projection_norm.weight")
converted_weights.append(weights)
elif path == _TRANSFORMER_FINAL_NORM:
converted_paths = ["norm.weight"]
converted_weights = [weights]
if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
raise ValueError(
"The `converted_paths` and `converted_weights` should be the same "
f"length. Got {cpl} and {cwl}, respectively, for {path}."
)
return zip(converted_paths, converted_weights)
def convert_vision_weights(
config: Gemma3nVisionConfig,
path: str,
param: str,
weights: np.ndarray,
) -> Iterable[tuple[str, np.ndarray]]:
def generate_base_path(path: str, block_type: str) -> tuple[str, tuple[int, int]]:
re_str = r"{}(\d+)/".format(block_type)
re_pattern = re.compile(re_str)
match = re.search(re_pattern, path).group(1)
idx = abs(int(match)) - 1
for block_idx, v in enumerate(_MOBILE_NET_TIMM_SUMMED_BLOCK_SIZES):
if v > idx:
offset = _MOBILE_NET_TIMM_SUMMED_BLOCK_SIZES[block_idx - 1] if block_idx > 0 else 0
layer_idx = idx - offset
return f"blocks.{block_idx}.{layer_idx}", (block_idx, layer_idx)
raise ValueError(f"could not extract a base path from {path}")
if _MOBILE_NET_MSFA in path:
converted_path = "msfa"
if "ffn/Normalize_0" in path:
converted_path += ".ffn.pw_exp.bn.weight"
converted_weight = weights
elif "ffn/Normalize_1" in path:
converted_path += ".ffn.pw_proj.bn.weight"
converted_weight = weights
elif "ffn/expand" in path:
converted_path += ".ffn.pw_exp.conv.weight"
converted_weight = weights.transpose()[:, :, None, None]
elif "ffn/project" in path:
converted_path += ".ffn.pw_proj.conv.weight"
converted_weight = weights.transpose()[:, :, None, None]
elif "Normalize_0" in path:
converted_path += ".norm.weight"
converted_weight = weights
elif _MOBILE_NET_CONV in path:
if "Conv_0" in path:
converted_path = "conv_stem.conv.weight"
converted_weight = weights.transpose(3, 2, 1, 0)
elif "Normalize_0" in path:
converted_path = "conv_stem.bn.weight"
converted_weight = weights
elif _MOBILE_NET_FIB in path:
converted_path, _ = generate_base_path(path, _MOBILE_NET_FIB)
if "Normalize_0" in path:
converted_path += ".bn1.weight"
converted_weight = weights
elif "Normalize_1" in path:
converted_path += ".bn2.weight"
converted_weight = weights
elif "expand_conv" in path:
converted_path += ".conv_exp.weight"
converted_weight = weights.transpose(3, 2, 1, 0)
else:
converted_path += ".conv_pwl.weight"
converted_weight = weights.transpose()[:, :, None, None]
elif _MOBILE_NET_MQA in path:
converted_path, _ = generate_base_path(path, _MOBILE_NET_MQA)
if "LayerScale_0" in path:
converted_path += ".layer_scale.gamma"
converted_weight = weights
elif "Normalize_0" in path:
converted_path += ".norm.weight"
converted_weight = weights
elif "Normalize_1" in path:
converted_path += ".attn.key.norm.weight"
converted_weight = weights
elif "Normalize_2" in path:
converted_path += ".attn.value.norm.weight"
converted_weight = weights
elif "key_dwconv" in path:
converted_path += ".attn.key.down_conv.weight"
converted_weight = weights.transpose()
elif "key_proj" in path:
converted_path += ".attn.key.proj.weight"
converted_weight = weights.transpose()[:, :, None, None]
elif "output_proj" in path:
converted_path += ".attn.output.proj.weight"
converted_weight = weights.transpose()[:, :, None, None]
elif "query_proj" in path:
converted_path += ".attn.query.proj.weight"
converted_weight = weights.transpose()[:, :, None, None]
elif "value_dwconv" in path:
converted_path += ".attn.value.down_conv.weight"
converted_weight = weights.transpose()
elif "value_proj" in path:
converted_path += ".attn.value.proj.weight"
converted_weight = weights.transpose()[:, :, None, None]
elif _MOBILE_NET_UIB in path:
converted_path, idx_key = generate_base_path(path, _MOBILE_NET_UIB)
has_dw_start = idx_key in _MOBILE_NET_UIB_HAS_DW_START
has_dw_mid = idx_key in _MOBILE_NET_UIB_HAS_DW_MID
if "LayerScale_0" in path:
converted_path += ".layer_scale.gamma"
converted_weight = weights
elif "Normalize_0" in path:
converted_path += ".dw_start.bn.weight" if has_dw_start else ".pw_exp.bn.weight"
converted_weight = weights
elif "Normalize_1" in path:
converted_path += ".pw_exp.bn.weight" if has_dw_start else ".pw_proj.bn.weight"
converted_weight = weights
elif "Normalize_2" in path:
converted_path += ".dw_mid.bn.weight" if has_dw_mid else ".pw_proj.bn.weight"
converted_weight = weights
elif "Normalize_3" in path:
converted_path += ".pw_proj.bn.weight"
converted_weight = weights
elif "expand" in path:
converted_path += ".pw_exp.conv.weight"
converted_weight = weights.transpose()[:, :, None, None]
elif "middle_dwconv" in path:
converted_path += ".dw_mid.conv.weight"
converted_weight = weights.transpose(3, 2, 1, 0)
elif "project" in path:
converted_path += ".pw_proj.conv.weight"
converted_weight = weights.transpose()[:, :, None, None]
elif "start_dwconv" in path:
converted_path += ".dw_start.conv.weight"
converted_weight = weights.transpose(3, 2, 1, 0)
return [(converted_path, converted_weight)]
def convert(checkpoint_path: str, config: Gemma3nConfig) -> dict[str, torch.Tensor]:
"""Loads Orbax checkpoint from `input_path` and converts it to HF tree."""
checkpointer = obc.PyTreeCheckpointer()
ckpt = checkpointer.restore(checkpoint_path)
hf_tree: dict[str, torch.Tensor] = {}
def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None:
hf_tree[path] = torch.from_numpy(weights.astype("float32")).type(target_dtype)
if _VERBOSE.value:
logging.info(
"%s converted shape=%s with dtype=%s",
path,
weights.shape,
target_dtype,
)
for (path, param), value in tree.flatten_with_path(ckpt):
if param == "audio_input_embedding_extra":
update_tree("model.embed_audio.embedding.weight", value, config.audio_config.torch_dtype)
elif path.endswith("audio_embedding_norm"):
update_tree("model.embed_audio.hard_embedding_norm.weight", value, config.audio_config.torch_dtype)
elif path.endswith("audio_input_projection"):
update_tree(
"model.embed_audio.embedding_projection.weight", value.transpose(), config.audio_config.torch_dtype
)
elif path.endswith("audio_soft_embedding_norm"):
update_tree("model.embed_audio.soft_embedding_norm.weight", value, config.audio_config.torch_dtype)
elif param == "mm_input_embedding_extra":
update_tree("model.embed_vision.embedding.weight", value, config.vision_config.torch_dtype)
elif path.endswith("mm_hard_embedding_norm"):
update_tree("model.embed_vision.hard_embedding_norm.weight", value, config.vision_config.torch_dtype)
elif path.endswith("mm_input_projection"):
update_tree(
"model.embed_vision.embedding_projection.weight", value.transpose(), config.vision_config.torch_dtype
)
elif path.endswith("mm_soft_embedding_norm"):
update_tree("model.embed_vision.soft_embedding_norm.weight", value, config.vision_config.torch_dtype)
elif path.startswith(_TRANSFORMER_PARAMETER):
for path, weights in convert_transformer_weights(config.text_config, path, param, value):
update_tree(f"model.language_model.{path}", weights, config.text_config.torch_dtype)
elif _MOBILE_NET_PREFIX in path:
mobilenet_prefix_idx = path.index(_MOBILE_NET_PREFIX)
path = path[mobilenet_prefix_idx:]
for path, weights in convert_vision_weights(config.vision_config, path, param, value):
update_tree(f"model.vision_tower.timm_model.{path}", weights, config.vision_config.torch_dtype)
elif path.startswith(_AUDIO_ENCODER_PARAMETER):
for path, weights in convert_audio_encoder_weights(config.audio_config, path, param, value):
update_tree(f"model.audio_tower.{path}", weights, config.audio_config.torch_dtype)
hf_tree["lm_head.weight"] = hf_tree["model.language_model.embed_tokens.weight"]
return hf_tree
def main(*args):
del args
output_path = _OUTPUT_PATH.value
variant = _VARIANT.value
config = _VARIANTS[variant]
config.audio_config.torch_dtype = getattr(torch, _AUDIO_DTYPE.value)
config.text_config.torch_dtype = getattr(torch, _TRANSFORMER_DTYPE.value)
config.vision_config.torch_dtype = getattr(torch, _VISION_DTYPE.value)
if _INCLUDE_CHAT_TEMPLATE.value:
# Chat template is included for instruction tuned models, which treat
# both "<eos>" and "<end_of_turn>" as generation stoppers.
config.eos_token_id = [1, 106]
logging.info(
"Converting Gemma 3 (%s) @ %s (language) and %s (vision)",
variant,
_TRANSFORMER_DTYPE.value,
_VISION_DTYPE.value,
)
state_tree = convert(_CHECKPOINT_PATH.value, config)
logging.info("Converted Gemma 3 (%s) state tree from Orbax to Hugging Face.", variant)
with accelerate.init_empty_weights():
model = Gemma3nForConditionalGeneration(config=config)
model.load_state_dict(state_tree, assign=True, strict=True)
logging.info(
"Loaded Gemma 3 (%s) in Hugging Face Transformers as a %s instance.",
variant,
type(model).__name__,
)
model.save_pretrained(output_path, state_dict=state_tree, safe_serialization=True)
logging.info(
"Saved Gemma 3 (%s) to SafeTensors in %s using %s",
variant,
output_path,
type(model).__name__,
)
del model
del state_tree
chat_template_kwargs = {"chat_template": _CHAT_TEMPLATE} if _INCLUDE_CHAT_TEMPLATE.value else {}
tokenizer = GemmaTokenizerFast(
_TOKENIZER_PATH.value,
add_bos_token=True,
extra_special_tokens={
"image_token": "<image_soft_token>", # Should be ID=262_145
"boi_token": "<start_of_image>", # Should be ID=255_999
"eoi_token": "<end_of_image>", # Should be ID=262_144
"audio_token": "<audio_soft_token>", # Should be ID=262_273
"boa_token": "<start_of_audio>", # Should be ID=256_000
"eoa_token": "<end_of_audio>", # Should be ID=262_272
},
**chat_template_kwargs,
)
tokenizer.save_pretrained(output_path)
logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path)
feature_extractor = Gemma3nAudioFeatureExtractor()
image_processor = SiglipImageProcessorFast(
image_seq_length=256,
image_mean=(0.5,) * 3,
image_std=(0.5,) * 3,
size={"height": 768, "width": 768},
resample=PILImageResampling.BILINEAR,
do_normalize=False,
)
processor = Gemma3nProcessor(
feature_extractor=feature_extractor,
image_processor=image_processor,
tokenizer=tokenizer,
**chat_template_kwargs,
)
processor.save_pretrained(output_path)
logging.info("Saved Gemma3nProcessor for %s to %s", variant, output_path)
# NOTE: feature_extractor and image_processor both use the same filename, preprocessor_config.json, when saved to
# disk, but the files are overwritten by processor.save_pretrained(). However, the configs can be unioned, saved,
# and loaded from the same preprocessor_config.json file, so we do that explicitly here.
feature_extractor_config = json.loads(feature_extractor.to_json_string())
image_processor_config = json.loads(image_processor.to_json_string())
preprocessor_config = {**feature_extractor_config, **image_processor_config}
with open(os.path.join(output_path, "preprocessor_config.json"), "w", encoding="utf-8") as writer:
writer.write(json.dumps(preprocessor_config, indent=2, sort_keys=True) + "\n")
logging.info("Saved joint preprocessor_config.json for %s to %s", variant, output_path)
del feature_extractor, image_processor, processor, tokenizer
generation_config = GenerationConfig(
pad_token_id=config.text_config.pad_token_id,
bos_token_id=config.text_config.bos_token_id,
eos_token_id=(
[config.text_config.eos_token_id, 106] if _INCLUDE_CHAT_TEMPLATE.value else config.text_config.eos_token_id
),
cache_implementation="hybrid",
temperature=1.0,
do_sample=True,
top_k=64,
top_p=0.95,
)
generation_config.save_pretrained(output_path)
if __name__ == "__main__":
app.run(main)

View File

@@ -0,0 +1,338 @@
# coding=utf-8
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Sequence
from typing import Optional, Union
import numpy as np
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging
logger = logging.get_logger(__name__)
def create_fb_matrix(
n_freqs: int,
f_min: float,
f_max: float,
n_mels: int,
sample_rate: int,
fft_length: int,
norm: Optional[str] = None,
) -> np.ndarray:
r"""Create a frequency bin conversion matrix (NumPy version).
Args:
n_freqs (int): Number of frequencies to highlight/apply
f_min (float): Minimum frequency (Hz)
f_max (float): Maximum frequency (Hz)
n_mels (int): Number of mel filterbanks
sample_rate (int): Sample rate of the audio waveform
fft_length (int): FFT length
norm (Optional[str]): If 'slaney', divide the triangular mel weights by
the width of the mel band (area normalization). (Default: ``None``)
Returns:
np.ndarray: Triangular filter banks (fb matrix) of size (``n_freqs``,
``n_mels``)
meaning number of frequencies to highlight/apply to x the number of
filterbanks.
Each column is a filterbank so that assuming there is a matrix A of
size (..., ``n_freqs``), the applied result would be
``A @ create_fb_matrix_numpy(A.shape[-1], ...)``.
"""
if norm is not None and norm != "slaney":
raise ValueError("norm must be one of None or 'slaney'")
# freq bins
all_freqs = np.arange(n_freqs, dtype=np.float32) * (sample_rate / fft_length)
# calculate mel freq bins
# hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
m_pts = np.linspace(m_min, m_max, n_mels + 2)
# mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
# calculate difference between each mel point and each stft freq point in Hz
f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
slopes = np.expand_dims(f_pts, 0) - np.expand_dims(all_freqs, 1) # (n_freqs, n_mels + 2)
# create overlapping triangles
zero = np.zeros(1, dtype=np.float32)
down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels)
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels)
fb = np.maximum(zero, np.minimum(down_slopes, up_slopes))
if norm is not None and norm == "slaney":
# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
fb *= np.expand_dims(enorm, 0)
return fb
def _unfold(array: np.ndarray, dimension: int, size: int, step: int) -> np.ndarray:
"""A basic NumPy equivalent of PyTorch's unfold for 2D arrays along the last dim."""
if array.ndim != 2:
raise ValueError("This unfold implementation currently supports 2D arrays (batch, time).")
if dimension != -1 and dimension != array.ndim - 1:
raise ValueError("This unfold implementation only supports unfolding the last dimension.")
batch_size, original_length = array.shape
num_frames = (original_length - size) // step + 1
if num_frames <= 0:
return np.zeros((batch_size, 0, size), dtype=array.dtype)
output_shape = (batch_size, num_frames, size)
output_strides = (array.strides[0], array.strides[1] * step, array.strides[1])
return np.lib.stride_tricks.as_strided(array, shape=output_shape, strides=output_strides)
class Gemma3nAudioFeatureExtractor(SequenceFeatureExtractor):
"""An audio feature extractor Universal Speech Models https://arxiv.org/abs/2303.01037.
Args:
feature_size (`int`, *optional*, defaults to 128):
The feature dimension of the extracted features.
sampling_rate (`int`, *optional*, defaults to 16000):
The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
padding_value (`float`, *optional*, defaults to 0.0):
Padding value used to pad the audio. Should correspond to silences.
return_attention_mask (`bool`, *optional*, defaults to `True`):
Whether to return the attention mask for the generated MEL spectrograms.
frame_length_ms (`float`, *optional*, defaults to 32.0):
The length of a frame in milliseconds.
hop_length_ms (`float`, *optional*, defaults to 10.0):
Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
min_frequency (`float`, *optional*, defaults to 125.0):
The minimum frequency (in Hz) for the Mel filterbank.
max_frequency (`float`, *optional*, defaults to 7600.0):
The maximum frequency (in Hz) for the Mel filterbank.
preemphasis (`float`, *optional*, defaults to 0.97):
The preemphasis coefficient.
preemphasis_htk_flavor (`bool`, *optional*, defaults to `True`):
Whether to use HTK-style preemphasis.
fft_overdrive (`bool`, *optional*, defaults to `True`):
Whether to use FFT overdrive.
dither (`float`, *optional*, defaults to 0.0):
Adds dithering. In other words, adds a small Gaussian noise to each frame.
E.g. use 0.0001 to add dithering with a normal distribution centered
around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech).
The value 0.0 means no dithering.
Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces
the high log_mel_fbank values for signals with hard-zero sections,
when VAD cutoff is present in the signal.
input_scale_factor (`float`, *optional*, defaults to 1.0):
Scaling factor applied to the input waveform.
mel_floor (`float`, *optional*, defaults to 1e-05):
Minimum value for Mel spectrograms to avoid log(0).
per_bin_mean (`Optional[Sequence[float]]`, *optional*):
Mean values for per-bin normalization.
per_bin_stddev (`Optional[Sequence[float]]`, *optional*):
Standard deviation values for per-bin normalization.
"""
model_input_names = ["input_features", "input_features_mask"]
def __init__(
self,
feature_size: int = 128,
sampling_rate: int = 16_000,
padding_value: float = 0.0,
return_attention_mask: bool = True,
frame_length_ms: float = 32.0,
hop_length_ms: float = 10.0,
min_frequency: float = 125.0,
max_frequency: float = 7600.0,
preemphasis: float = 0.97,
preemphasis_htk_flavor: bool = True,
fft_overdrive: bool = True,
dither: float = 0.0,
input_scale_factor: float = 1.0,
mel_floor: float = 1e-5,
per_bin_mean: Optional[Sequence[float]] = None,
per_bin_stddev: Optional[Sequence[float]] = None,
**kwargs,
):
super().__init__(
feature_size=feature_size,
sampling_rate=sampling_rate,
padding_value=padding_value,
return_attention_mask=return_attention_mask,
**kwargs,
)
self.min_frequency = min_frequency
self.max_frequency = max_frequency
self.preemphasis = preemphasis
self.preemphasis_htk_flavor = preemphasis_htk_flavor
self.fft_overdrive = fft_overdrive
self.dither = dither
self.input_scale_factor = input_scale_factor
self.frame_length = int(round(sampling_rate * frame_length_ms / 1000.0))
self.hop_length = int(round(sampling_rate * hop_length_ms / 1000.0))
self.mel_floor = np.array(mel_floor, dtype=np.float64)
fft_length = 2 ** math.ceil(math.log2(self.frame_length))
if self.fft_overdrive:
fft_length *= 2
self.fft_length = fft_length
hann_arange = np.arange(self.frame_length, dtype=np.float32)
window = 0.5 * (1 - np.cos(2 * np.pi * hann_arange / self.frame_length))
self.window = window.astype(np.float32)
self.mel_filters = create_fb_matrix(
n_freqs=self.fft_length // 2 + 1,
f_min=min_frequency,
f_max=max_frequency,
n_mels=feature_size,
sample_rate=self.sampling_rate,
norm=None,
fft_length=fft_length,
)
if per_bin_mean is not None:
self.per_bin_mean = np.array(per_bin_mean).reshape(1, 1, feature_size)
else:
self.per_bin_mean = None
if per_bin_stddev is not None:
self.per_bin_stddev = np.array(per_bin_stddev).reshape(1, 1, feature_size)
else:
self.per_bin_stddev = None
def _extract_spectrogram(self, waveform: np.ndarray, attention_mask: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
""""""
if waveform.ndim == 1: # If single waveform, add batch dimension
waveform = np.expand_dims(waveform, axis=0)
if self.dither > 0.0:
waveform = waveform + self.dither * np.random.randn(*waveform.shape).astype(waveform.dtype)
if self.input_scale_factor != 1.0:
waveform = waveform * self.input_scale_factor
frame_size_for_unfold = self.frame_length + 1
# NumPy equivalent of unfold for [B, NumFrames, frame_size_for_unfold]
frames_to_process = _unfold(waveform, dimension=-1, size=frame_size_for_unfold, step=self.hop_length)
if self.preemphasis > 0.0:
if self.preemphasis_htk_flavor:
first_in_frame = frames_to_process[..., :1] * (1.0 - self.preemphasis)
rest_in_frame = frames_to_process[..., 1:-1] - self.preemphasis * frames_to_process[..., :-2]
frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1)
else:
frames = frames_to_process[..., 1:] - self.preemphasis * frames_to_process[..., :-1]
else:
frames = frames_to_process[..., :-1]
frames = frames * self.window # Broadcasting window
stft = np.fft.rfft(frames, n=self.fft_length, axis=-1)
magnitude_spec = np.abs(stft)
mel_spec = np.matmul(magnitude_spec, self.mel_filters)
log_mel_spec = np.log(np.maximum(mel_spec, self.mel_floor))
if self.per_bin_mean is not None:
log_mel_spec = log_mel_spec - self.per_bin_mean # Broadcasting
if self.per_bin_stddev is not None:
log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting
mel_spectrogram = log_mel_spec.squeeze()
mask = attention_mask[:: self.hop_length].astype(bool)
# TODO: The filtered mask is always exactly 3 elements longer than the mel_spectrogram. Why???
return mel_spectrogram, mask[: mel_spectrogram.shape[0]]
def __call__(
self,
raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
padding: Union[bool, str, PaddingStrategy] = "longest",
max_length: Optional[int] = 480_000,
truncation: bool = True,
pad_to_multiple_of: Optional[int] = 128,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: Optional[bool] = True,
**kwargs,
) -> BatchFeature:
"""Creates a batch of MEL spectrograms from the provided raw speech.
This implementation uses a different algorithm for windowing and preemphasis compared to the built-in
`transformers.audio_utils.spectrogram()` function that _will_ result in different outputs. Consider this
carefully when selecting an audio feature extactor, especially with pre-trained models.
Args:
raw_speech:
The audio for which MEL spectrograms are created.
padding (`Union[bool, str, PaddingStrategy]`, *optional*, defaults to `"longest"`):
The padding strategy to use for batches of audio with different lengths.
max_length (`int`, *optional*, defaults to 480000):
If provided, defines the maximum length of the audio to allow. Audio longer than this will be
truncated if `truncation=True`.
truncation (`bool`, *optional*, defaults to `True`):
Whether or not to truncate audio above `max_length`.
pad_to_multiple_of (`int`, *optional*, defaults to 128):
When padding, pad to a multiple of this value. The default value is defined for optimal TPU support.
return_tensors (`Union[str, TensorType]`, *optional*, defaults to `None`):
The type of tensors to return (e.g., NumPy, Torch, JAX, TensorFlow).
return_attention_mask (`bool`, *optional*, defaults to `True`):
Whether to return the attention mask for the generated MEL spectrograms.
"""
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
is_batched_sequence = isinstance(raw_speech, Sequence) and isinstance(raw_speech[0], (np.ndarray, Sequence))
is_batched = is_batched_numpy or is_batched_sequence
if is_batched:
raw_speech = [np.asarray([rs]).T for rs in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech)
if not is_batched: # always return a batch
raw_speech = [np.asarray([raw_speech])]
batched_speech = self.pad(
BatchFeature({"input_features": raw_speech}),
padding=padding,
max_length=max_length,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
prepared_speech = []
prepared_speech_mask = []
for speech, mask in zip(batched_speech.input_features, batched_speech.attention_mask):
speech, mask = self._extract_spectrogram(speech.T, mask)
prepared_speech.append(speech.astype(np.float32))
prepared_speech_mask.append(mask)
return BatchFeature(
{"input_features": prepared_speech, "input_features_mask": prepared_speech_mask},
tensor_type=return_tensors,
)
__all__ = ["Gemma3nAudioFeatureExtractor"]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,191 @@
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, make_nested_list_of_images
from ...processing_utils import AudioKwargs, ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
class Gemma3nImagesKwargs(ImagesKwargs):
do_pan_and_scan: Optional[bool]
pan_and_scan_min_crop_size: Optional[int]
pan_and_scan_max_num_crops: Optional[int]
pan_and_scan_min_ratio_to_activate: Optional[float]
do_convert_rgb: Optional[bool]
class Gemma3nProcessorKwargs(ProcessingKwargs, total=False):
audio_kwargs: AudioKwargs
images_kwargs: Gemma3nImagesKwargs
_defaults = {
"text_kwargs": {
"padding": False,
},
}
class Gemma3nProcessor(ProcessorMixin):
"""
A processor for Gemma 3n, wrapping the full capabilities of a feature extractor, image processor, and tokenizer
into a single processor.
Args:
feature_extractor (`Gemma3nAudioFeatureExtractor`):
Feature extractor that converts raw audio waveforms into MEL spectrograms for the audio encoder. This
should return a `BatchFeature` with `input_features` and `input_features_mask` features.
image_processor (`SiglipImageProcessorFast`):
Image processor that prepares batches of images for the vision encoder. This should return a `BatchFeature`
with a `pixel_values` feature.
tokenizer (`GemmaTokenizerFast`):
The text tokenizer for the model.
chat_template (`string`, *optional*):
A Jinja template for generating text prompts from a set of messages.
audio_seq_length (int, *optional*, defaults to 188):
The number of audio soft tokens that will be added to the text prompt
image_seq_length (int, *optional*, defaults to 256):
The number of image soft tokens that should be added to
"""
attributes = ["feature_extractor", "image_processor", "tokenizer"]
feature_extractor_class = "AutoFeatureExtractor"
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
feature_extractor,
image_processor,
tokenizer,
chat_template=None,
audio_seq_length: int = 188,
image_seq_length: int = 256,
**kwargs,
):
self.audio_seq_length = audio_seq_length
self.audio_token_id = tokenizer.audio_token_id
self.boa_token = tokenizer.boa_token
self.audio_token = tokenizer.audio_token
audio_tokens_expanded = "".join([tokenizer.audio_token] * audio_seq_length)
self.full_audio_sequence = f"\n\n{tokenizer.boa_token}{audio_tokens_expanded}{tokenizer.eoa_token}\n\n"
self.image_seq_length = image_seq_length
self.image_token_id = tokenizer.image_token_id
self.boi_token = tokenizer.boi_token
self.image_token = tokenizer.image_token
image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
super().__init__(
feature_extractor=feature_extractor,
image_processor=image_processor,
tokenizer=tokenizer,
chat_template=chat_template,
**kwargs,
)
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
audio: Optional[Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]]] = None,
videos=None,
**kwargs: Unpack[Gemma3nProcessorKwargs],
) -> BatchFeature:
if text is None and images is None and audio is None:
raise ValueError("Provide at least one of `text`, `images`, or `audio`.")
output_kwargs = self._merge_kwargs(
Gemma3nProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
if audio is not None:
audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
if not text:
text = [self.audio_token for _ in audio]
# Expand placeholder audio tokens to the full audio token sequence
text = [prompt.replace(self.audio_token, self.full_audio_sequence) for prompt in text]
else:
audio_inputs = {}
if images is not None:
batched_images = make_nested_list_of_images(images)
image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])
# Create empty text to be replaced with placeholders
if not text:
text = [" ".join([self.image_token] * len(images)) for images in batched_images]
if len(batched_images) != len(text):
raise ValueError(
f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
)
# Expand placeholder image tokens to the full image token sequence
text = [prompt.replace(self.image_token, self.full_image_sequence) for prompt in text]
else:
image_inputs = {}
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
# Add token type ids manually, as tokenizer can't do arbitrary position token types
array_ids = text_inputs["input_ids"]
token_type_ids = np.zeros_like(array_ids)
token_type_ids[array_ids == self.image_token_id] = 1
token_type_ids[array_ids == self.audio_token_id] = 3
text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
text_inputs["token_type_ids"] = token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
image_processor_input_names = self.image_processor.model_input_names
feature_extactor_input_names = self.feature_extractor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + feature_extactor_input_names))
__all__ = ["Gemma3nProcessor"]

View File

@@ -1642,6 +1642,7 @@ def set_model_tester_for_less_flaky_test(test_case):
"AriaVisionText2TextModelTester", "AriaVisionText2TextModelTester",
"GPTNeoModelTester", "GPTNeoModelTester",
"DPTModelTester", "DPTModelTester",
"Gemma3nTextModelTester", # cannot have a single layer combined with the cache sharing config attrs in the tester
] ]
if test_case.model_tester.__class__.__name__ in exceptional_classes: if test_case.model_tester.__class__.__name__ in exceptional_classes:
target_num_hidden_layers = None target_num_hidden_layers = None

View File

View File

@@ -0,0 +1,277 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import os
import random
import tempfile
import unittest
from typing import Optional, Sequence
import numpy as np
from parameterized import parameterized
from transformers.models.gemma3n import Gemma3nAudioFeatureExtractor
from transformers.testing_utils import (
check_json_file_has_correct_format,
require_torch,
)
from transformers.utils.import_utils import is_torch_available
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
if is_torch_available():
pass
global_rng = random.Random()
MAX_LENGTH_FOR_TESTING = 512
def floats_list(shape, scale=1.0, rng=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng
values = []
for _ in range(shape[0]):
values.append([])
for _ in range(shape[1]):
values[-1].append(rng.random() * scale)
return values
class Gemma3nAudioFeatureExtractionTester:
def __init__(
self,
parent,
batch_size=7,
min_seq_length=400,
max_seq_length=2000,
feature_size: int = 128,
sampling_rate: int = 16_000,
padding_value: float = 0.0,
return_attention_mask: bool = False,
# ignore hop_length / frame_length for now, as ms -> length conversion causes issues with serialization tests
# frame_length_ms: float = 32.0,
# hop_length: float = 10.0,
min_frequency: float = 125.0,
max_frequency: float = 7600.0,
preemphasis: float = 0.97,
preemphasis_htk_flavor: bool = True,
fft_overdrive: bool = True,
dither: float = 0.0,
input_scale_factor: float = 1.0,
mel_floor: float = 1e-5,
per_bin_mean: Optional[Sequence[float]] = None,
per_bin_stddev: Optional[Sequence[float]] = None,
):
self.parent = parent
self.batch_size = batch_size
self.min_seq_length = min_seq_length
self.max_seq_length = max_seq_length
self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
self.feature_size = feature_size
self.sampling_rate = sampling_rate
self.padding_value = padding_value
self.return_attention_mask = return_attention_mask
# ignore hop_length / frame_length for now, as ms -> length conversion causes issues with serialization tests
# self.frame_length_ms = frame_length_ms
# self.hop_length = hop_length
self.min_frequency = min_frequency
self.max_frequency = max_frequency
self.preemphasis = preemphasis
self.preemphasis_htk_flavor = preemphasis_htk_flavor
self.fft_overdrive = fft_overdrive
self.dither = dither
self.input_scale_factor = input_scale_factor
self.mel_floor = mel_floor
self.per_bin_mean = per_bin_mean
self.per_bin_stddev = per_bin_stddev
def prepare_feat_extract_dict(self):
return {
"feature_size": self.feature_size,
"sampling_rate": self.sampling_rate,
"padding_value": self.padding_value,
"return_attention_mask": self.return_attention_mask,
"min_frequency": self.min_frequency,
"max_frequency": self.max_frequency,
"preemphasis": self.preemphasis,
"preemphasis_htk_flavor": self.preemphasis_htk_flavor,
"fft_overdrive": self.fft_overdrive,
"dither": self.dither,
"input_scale_factor": self.input_scale_factor,
"mel_floor": self.mel_floor,
"per_bin_mean": self.per_bin_mean,
"per_bin_stddev": self.per_bin_stddev,
}
def prepare_inputs_for_common(self, equal_length=False, numpify=False):
def _flatten(list_of_lists):
return list(itertools.chain(*list_of_lists))
if equal_length:
speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
else:
# make sure that inputs increase in size
speech_inputs = [
floats_list((x, self.feature_size))
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
]
if numpify:
speech_inputs = [np.asarray(x) for x in speech_inputs]
return speech_inputs
class Gemma3nAudioFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = Gemma3nAudioFeatureExtractor
def setUp(self):
self.feat_extract_tester = Gemma3nAudioFeatureExtractionTester(self)
def test_feat_extract_from_and_save_pretrained(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
mel_1 = feat_extract_first.mel_filters
mel_2 = feat_extract_second.mel_filters
self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second)
def test_feat_extract_to_json_file(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
json_file_path = os.path.join(tmpdirname, "feat_extract.json")
feat_extract_first.to_json_file(json_file_path)
feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
mel_1 = feat_extract_first.mel_filters
mel_2 = feat_extract_second.mel_filters
self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second)
def test_feat_extract_from_pretrained_kwargs(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(
tmpdirname, feature_size=2 * self.feat_extract_dict["feature_size"]
)
mel_1 = feat_extract_first.mel_filters
mel_2 = feat_extract_second.mel_filters
self.assertTrue(2 * mel_1.shape[1] == mel_2.shape[1])
@parameterized.expand(
[
([floats_list((1, x))[0] for x in range(800, 1400, 200)],),
([floats_list((1, x))[0] for x in (800, 800, 800)],),
([floats_list((1, x))[0] for x in range(200, (MAX_LENGTH_FOR_TESTING + 500), 200)], True),
]
)
def test_call(self, audio_inputs, test_truncation=False):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
np_audio_inputs = [np.asarray(audio_input) for audio_input in audio_inputs]
input_features = feature_extractor(np_audio_inputs, padding="max_length", return_tensors="np").input_features
self.assertTrue(input_features.ndim == 3)
# input_features.shape should be (batch, num_frames, n_mels) ~= (batch, num_frames, feature_size)
# 480_000 is the max_length that inputs are padded to. we use that to calculate num_frames
expected_num_frames = (480_000 - feature_extractor.frame_length) // (feature_extractor.hop_length) + 1
self.assertTrue(
input_features.shape[-2] == expected_num_frames,
f"no match: {input_features.shape[-1]} vs {expected_num_frames}",
)
self.assertTrue(input_features.shape[-1] == feature_extractor.feature_size)
encoded_sequences_1 = feature_extractor(audio_inputs, return_tensors="np").input_features
encoded_sequences_2 = feature_extractor(np_audio_inputs, return_tensors="np").input_features
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
if test_truncation:
audio_inputs_truncated = [x[:MAX_LENGTH_FOR_TESTING] for x in audio_inputs]
np_audio_inputs_truncated = [np.asarray(audio_input) for audio_input in audio_inputs_truncated]
encoded_sequences_1 = feature_extractor(
audio_inputs_truncated, max_length=MAX_LENGTH_FOR_TESTING, return_tensors="np"
).input_features
encoded_sequences_2 = feature_extractor(
np_audio_inputs_truncated, max_length=MAX_LENGTH_FOR_TESTING, return_tensors="np"
).input_features
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_dither(self):
np.random.seed(42) # seed the dithering randn()
# Tests that features with and without little dithering are similar, but not the same
dict_no_dither = self.feat_extract_tester.prepare_feat_extract_dict()
dict_no_dither["dither"] = 0.0
dict_dither = self.feat_extract_tester.prepare_feat_extract_dict()
dict_dither["dither"] = 0.00003 # approx. 1/32k
feature_extractor_no_dither = self.feature_extraction_class(**dict_no_dither)
feature_extractor_dither = self.feature_extraction_class(**dict_dither)
# create three inputs of length 800, 1000, and 1200
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
# compute features
input_features_no_dither = feature_extractor_no_dither(
np_speech_inputs, padding=True, return_tensors="np", sampling_rate=dict_no_dither["sampling_rate"]
).input_features
input_features_dither = feature_extractor_dither(
np_speech_inputs, padding=True, return_tensors="np", sampling_rate=dict_dither["sampling_rate"]
).input_features
# test there is a difference between features (there's added noise to input signal)
diff = input_features_dither - input_features_no_dither
# features are not identical
self.assertTrue(np.abs(diff).mean() > 1e-6)
# features are not too different
self.assertTrue(np.abs(diff).mean() <= 1e-4)
self.assertTrue(np.abs(diff).max() <= 5e-3)
@require_torch
def test_double_precision_pad(self):
import torch
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
py_speech_inputs = np_speech_inputs.tolist()
for inputs in [py_speech_inputs, np_speech_inputs]:
np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
self.assertTrue(np_processed.input_features.dtype == np.float32)
pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
self.assertTrue(pt_processed.input_features.dtype == torch.float32)

View File

@@ -0,0 +1,886 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Gemma3n model."""
import tempfile
import unittest
import numpy as np
import pytest
from datasets import load_dataset
from parameterized import parameterized
from transformers import (
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
Gemma3nAudioConfig,
Gemma3nAudioFeatureExtractor,
Gemma3nConfig,
Gemma3nTextConfig,
GenerationConfig,
is_torch_available,
)
from transformers.testing_utils import (
cleanup,
require_flash_attn,
require_read_token,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ..gemma.test_modeling_gemma import GemmaModelTester
if is_torch_available():
import torch
from transformers import (
Gemma3nAudioEncoder,
Gemma3nForCausalLM,
Gemma3nForConditionalGeneration,
Gemma3nModel,
Gemma3nTextModel,
)
class Gemma3nAudioModelTester:
def __init__(
self,
parent,
batch_size=2,
num_channels=32, # feature_size / input_feat_size
sampling_rate=16_000,
raw_audio_length=8_000,
is_training=True,
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.sampling_rate = sampling_rate
self.raw_audio_length = raw_audio_length
self.is_training = is_training
def get_feature_extractor_config(self):
return {
"feature_size": self.num_channels,
"sampling_rate": self.sampling_rate,
"padding_value": 0.0,
"return_attention_mask": True,
"frame_length_ms": 32.0,
"hop_length_ms": 10.0,
"dither": 0.0, # Important for determinism
}
def get_audio_encoder_config(self):
return Gemma3nAudioConfig(
input_feat_size=self.num_channels,
hidden_size=32,
conf_num_attention_heads=4,
conf_num_hidden_layers=2,
sscp_conv_channel_size=(16, 8),
conf_conv_kernel_size=3,
conf_attention_chunk_size=4,
conf_attention_context_left=5,
)
def prepare_config_and_inputs_for_common(self):
# Prepare inputs for the audio encoder
feature_extractor_config = self.get_feature_extractor_config()
audio_encoder_config = self.get_audio_encoder_config()
np.random.seed(0)
raw_speech_1 = np.sin(2 * np.pi * 440 * np.linspace(0, 1, self.raw_audio_length)).astype(np.float32)
raw_speech_2 = np.random.randn(self.raw_audio_length // 2).astype(np.float32)
raw_speech = [raw_speech_1, raw_speech_2]
feature_extractor = Gemma3nAudioFeatureExtractor(**feature_extractor_config)
audio_inputs = feature_extractor(raw_speech, return_tensors="pt")
input_features = audio_inputs["input_features"]
# The encoder expects a padding mask (True for padding), while the feature extractor
# returns an attention mask (True for valid tokens). We must invert it.
input_features_mask = ~audio_inputs["input_features_mask"].to(torch.bool)
inputs_dict = {
"audio_mel": input_features,
"audio_mel_mask": input_features_mask,
}
return audio_encoder_config, inputs_dict
@unittest.skip("Skipped for now!")
@require_torch
class Gemma3nAudioModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (Gemma3nAudioEncoder,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
test_missing_keys = False
is_generative = False
_is_stateful = True
main_input_name = "audio_mel"
test_initialization = False
test_can_init_all_missing_weights = False
def setUp(self):
self.model_tester = Gemma3nAudioModelTester(self)
self.config_tester = ConfigTester(self, config_class=Gemma3nAudioConfig, hidden_size=37)
torch.manual_seed(0)
# The following values are golden outputs from a deterministic run of the components.
# They are used to ensure that changes to the code do not alter the numerical output.
# Generated with seeds np.random.seed(0) and torch.manual_seed(0).
self.expected_input_features_shape = (2, 48, 32)
self.expected_input_features_slice = np.array([-5.733152, -5.337127, -4.916284, -4.378989, -3.7622747])
self.expected_input_features_mask_shape = (2, 48)
self.expected_input_features_mask_slice = np.array([True, True, True, True, False])
self.expected_encoder_output_shape = (2, 3, 32)
self.expected_encoder_output_slice = torch.tensor([-0.4159, 0.6459, 0.6305, 2.2902, 0.9683])
self.expected_encoder_mask_shape = (2, 3)
self.expected_encoder_mask_slice = torch.tensor([False, False, True])
# Prepare a shared feature extractor and raw audio for the tests
self.feature_extractor = Gemma3nAudioFeatureExtractor(**self.model_tester.get_feature_extractor_config())
np.random.seed(0)
raw_speech_1 = np.sin(2 * np.pi * 440 * np.linspace(0, 1, self.model_tester.raw_audio_length)).astype(
np.float32
)
raw_speech_2 = np.random.randn(self.model_tester.raw_audio_length // 2).astype(np.float32)
self.raw_speech = [raw_speech_1, raw_speech_2]
@unittest.skip("Audio encoder does not support attention output")
def test_attention_outputs(self):
pass
@unittest.skip("Audio encoder does not support hidden state output")
def test_hidden_states_output(self):
pass
@unittest.skip("Audio encoder returns a tuple, not a ModelOutput object, skipping equivalence test.")
def test_model_outputs_equivalence(self):
pass
@unittest.skip("Audio encoder does not support retaining gradients on hidden states/attentions.")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip("Audio encoder does not have a concept of token embeddings")
def test_model_get_set_embeddings(self):
pass
@unittest.skip("Audio encoder does not have a concept of token embeddings")
def test_resize_tokens_embeddings(self):
pass
@unittest.skip("This model has a complex downsampling scheme that is hard to test with the generic batching test.")
def test_batching_equivalence(self):
pass
def test_feature_extractor(self):
"""
Tests the feature extractor's output against pre-computed golden values.
This ensures the NumPy-based audio preprocessing is correct and consistent.
"""
audio_inputs = self.feature_extractor(
self.raw_speech, padding="longest", pad_to_multiple_of=128, return_tensors="np"
)
input_features = audio_inputs["input_features"]
self.assertEqual(input_features.shape, self.expected_input_features_shape)
np.testing.assert_allclose(input_features[0, 0, :5], self.expected_input_features_slice, rtol=1e-5, atol=1e-5)
print(input_features[0, 0, :5])
input_features_mask = audio_inputs["input_features_mask"]
self.assertEqual(input_features_mask.shape, self.expected_input_features_mask_shape)
# The second audio sample is shorter (22 frames vs 48), so its mask should become False at index 22
np.testing.assert_array_equal(input_features_mask[1, 21:26], self.expected_input_features_mask_slice)
def test_audio_encoder(self):
"""
Tests the audio encoder's forward pass against pre-computed golden values.
This ensures the PyTorch-based audio encoding model is correct and consistent.
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = Gemma3nAudioEncoder(config).to(torch_device).eval()
with torch.no_grad():
encoder_output, encoder_mask = model(**inputs_dict)
print(encoder_output[0, 0, :5])
# Check output encodings
self.assertEqual(encoder_output.shape, self.expected_encoder_output_shape)
torch.testing.assert_close(
encoder_output[0, 0, :5], self.expected_encoder_output_slice.to(torch_device), rtol=1e-4, atol=1e-4
)
# Check output mask (True means padded)
# Second sample has 22 feature frames. After downsampling by 4 (conv) -> 5 frames. After downsampling by 4 (reduction) -> 1 frame.
# So the mask should be [False, True, True]
self.assertEqual(encoder_mask.shape, self.expected_encoder_mask_shape)
torch.testing.assert_close(encoder_mask[1, :], self.expected_encoder_mask_slice.to(torch_device))
class Gemma3nTextModelTester(GemmaModelTester):
activation_sparsity_pattern = None
forced_config_args = ["activation_sparsity_pattern"]
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
vocab_size_per_layer_input=99,
hidden_size=16,
num_hidden_layers=4, # override to correctly test sharing cache pattern
num_kv_shared_layers=2, # important to override
layer_types=[
"full_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
], # similarly we want to test sharing on both types
num_attention_heads=2,
num_key_value_heads=2,
altup_num_inputs=2,
intermediate_size=21,
hidden_activation="gelu_pytorch_tanh",
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
is_decoder=False,
):
self._verify_model_attributes()
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.vocab_size_per_layer_input = vocab_size_per_layer_input
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_kv_shared_layers = num_kv_shared_layers
self.layer_types = layer_types
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.altup_num_inputs = altup_num_inputs
self.intermediate_size = intermediate_size
self.hidden_activation = hidden_activation
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.head_dim = self.hidden_size // self.num_attention_heads
self.is_decoder = is_decoder
if is_torch_available():
config_class = Gemma3nTextConfig
model_class = Gemma3nTextModel
for_causal_lm_class = Gemma3nForCausalLM
@unittest.skip("Skipped for now!")
@require_torch
class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Gemma3nTextModel, Gemma3nForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (Gemma3nForCausalLM,) if is_torch_available() else ()
test_headmasking = False
test_pruning = False
_is_stateful = True
model_split_percents = [0.5, 0.6]
def setUp(self):
self.model_tester = Gemma3nTextModelTester(self)
self.config_tester = ConfigTester(
self,
config_class=Gemma3nConfig,
hidden_size=37,
text_config={"activation_sparsity_pattern": None},
)
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
):
"Gemma3n has special hidden states shape with 1 additional dim (which is then reduced with projections)"
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (output_length - prompt_length))
# When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the
# new token(s)
# NOTE: `HybridCache` may have different lengths on different layers, if this test starts failing add more
# elaborate checks
for generated_length, iter_hidden_states in enumerate(hidden_states):
# regardless of using cache, the first forward pass will have the full prompt as input
if use_cache and generated_length > 0:
model_input_length = 1
else:
model_input_length = prompt_length + generated_length
expected_shape = (config.altup_num_inputs, batch_size, model_input_length, config.hidden_size)
# check hidden size
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
[expected_shape] * len(iter_hidden_states),
)
class Gemma3nVision2TextModelTester:
text_config = {"activation_sparsity_pattern": None}
forced_config_args = ["text_config"]
def __init__(
self,
parent,
mm_tokens_per_image=2,
image_token_index=1,
boi_token_index=2,
eoi_token_index=3,
seq_length=25,
is_training=True,
vision_config={
"use_labels": True,
"image_size": 20,
"patch_size": 5,
"num_channels": 3,
"is_training": True,
"hidden_size": 32,
"num_key_value_heads": 1,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"intermediate_size": 37,
"dropout": 0.1,
"attention_dropout": 0.1,
"initializer_range": 0.02,
},
use_cache=False,
):
self.parent = parent
# `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify
self.mm_tokens_per_image = mm_tokens_per_image
self.image_token_index = image_token_index
self.boi_token_index = boi_token_index
self.eoi_token_index = eoi_token_index
self.llm_tester = Gemma3nTextModelTester(self.parent)
self.text_config = self.llm_tester.get_config()
self.vision_config = vision_config
self.seq_length = seq_length
self.pad_token_id = self.text_config.pad_token_id
self.num_hidden_layers = self.text_config.num_hidden_layers
self.vocab_size = self.text_config.vocab_size
self.hidden_size = self.text_config.hidden_size
self.num_attention_heads = self.text_config.num_attention_heads
self.is_training = is_training
self.batch_size = 3
self.num_channels = vision_config["num_channels"]
self.image_size = vision_config["image_size"]
self.encoder_seq_length = seq_length
self.use_cache = use_cache
def get_config(self):
return Gemma3nConfig(
text_config=self.text_config,
vision_config=self.vision_config,
image_token_index=self.image_token_index,
boi_token_index=self.boi_token_index,
eoi_token_index=self.eoi_token_index,
mm_tokens_per_image=self.mm_tokens_per_image,
)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor(
[
self.batch_size,
self.vision_config["num_channels"],
self.vision_config["image_size"],
self.vision_config["image_size"],
]
)
config = self.get_config()
return config, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
# set the 3 first tokens to be image, and ensure that no other tokens are image tokens
# do not change this unless you modified image size or patch size
input_ids[input_ids == config.image_token_index] = self.pad_token_id
input_ids[:, :1] = config.image_token_index
token_type_ids = torch.zeros_like(input_ids)
token_type_ids[input_ids == config.image_token_index] = 1
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
return config, inputs_dict
@unittest.skip("Skipped for now!")
@require_torch
class Gemma3nVision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Gemma3nModel, Gemma3nForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (Gemma3nForConditionalGeneration,) if is_torch_available() else ()
test_headmasking = False
test_pruning = False
test_missing_keys = False
_is_stateful = True
model_split_percents = [0.5, 0.6]
# MP works but offload doesn't work when the SigLIP MultiheadAttention is offloaded
# TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
# in the dispatch_model function
test_cpu_offload = False
test_disk_offload_safetensors = False
test_disk_offload_bin = False
def setUp(self):
self.model_tester = Gemma3nVision2TextModelTester(self)
self.config_tester = ConfigTester(
self,
config_class=Gemma3nConfig,
hidden_size=37,
text_config={"activation_sparsity_pattern": None},
)
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
" as in Dynamic Cache doesnt work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
)
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
pass
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@pytest.mark.generate
@unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("Gemma3n has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("Gemma3n has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("Gemma3n has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("Gemma3n has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_with_static_cache(self):
pass
@unittest.skip("Gemma3n has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip(
reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation"
)
def test_initialization(self):
pass
@unittest.skip(
reason="Siglip has no FLEX attention, and we don't have a proper way to set/test attn in VLMs. TODO @raushan"
)
def test_flex_attention_with_grads(self):
pass
def test_automodelforcausallm(self):
"""
Regression test for #36741 -- make sure `AutoModelForCausalLM` works with a Gemma3n config, i.e. that
`AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model
"""
config = self.model_tester.get_config()
model = Gemma3nForConditionalGeneration(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
self.assertIsInstance(for_causal_lm, Gemma3nForCausalLM)
@unittest.skip("Skipped for now!")
@slow
@require_torch_gpu
@require_read_token
class Gemma3nIntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = AutoProcessor.from_pretrained("Google/gemma-3n-E4B-it", padding_side="left")
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
self.messages = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{"type": "image", "url": url},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
audio_ds = load_dataset(
"etechgrid/28.5k_wavfiles_dataset", "default", data_files="wav_dataset/103-1240-0000.wav"
)
self.audio_file_path = audio_ds["train"][0]["audio"]["path"]
def tearDown(self):
cleanup(torch_device, gc_collect=True)
def test_model_4b_bf16(self):
model_id = "Google/gemma-3n-E4B-it"
model = Gemma3nForConditionalGeneration.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
).to(torch_device)
inputs = self.processor.apply_chat_template(
self.messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
).to(torch_device)
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like'] # fmt: skip
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_with_audio(self):
"""
Tests the full model pipeline with batched audio inputs provided as file paths.
This ensures the processor correctly loads and processes audio files.
"""
model_id = "Google/gemma-3n-E4B-it"
model = Gemma3nForConditionalGeneration.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
).to(torch_device)
messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "Transcribe the following speech segment in English:"},
{"type": "audio", "audio": str(self.audio_file_path)},
],
}
],
]
inputs = self.processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
padding=True,
return_tensors="pt",
).to(torch_device, dtype=model.dtype)
input_len = inputs["input_ids"].shape[-1]
output = model.generate(**inputs, max_new_tokens=16, do_sample=False)
output = output[:, input_len:]
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = ["Chapter 1. Mrs. Rachel Lind is surprised.\n\nMrs. Rachel Lind"]
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_4b_batch(self):
model_id = "Google/gemma-3n-E4B-it"
model = Gemma3nForConditionalGeneration.from_pretrained(
model_id, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
).to(torch_device)
messages_2 = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
},
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "Are these images identical?"},
],
},
]
inputs = self.processor.apply_chat_template(
[self.messages, messages_2],
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
add_generation_prompt=True,
).to(torch_device)
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = [
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
"user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow"
] # fmt: skip
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_4b_crops(self):
model_id = "Google/gemma-3n-E4B-it"
model = Gemma3nForConditionalGeneration.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
).to(torch_device)
crop_config = {
"images_kwargs": {
"do_pan_and_scan": True,
"pan_and_scan_max_num_crops": 448,
"pan_and_scan_min_crop_size": 32,
"pan_and_scan_min_ratio_to_activate": 0.3,
}
}
inputs = self.processor.apply_chat_template(
self.messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
**crop_config,
).to(torch_device)
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_NUM_IMAGES = 3 # one for the origin image and two crops of images
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a beach with a turquoise ocean and blue sky in the background.'] # fmt: skip
self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES)
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_4b_multiimage(self):
model_id = "Google/gemma-3n-E4B-it"
model = Gemma3nForConditionalGeneration.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
).to(torch_device)
messages = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "What do you see here?"},
],
},
]
inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
add_generation_prompt=True,
).to(torch_device)
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"] # fmt: skip
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_1b_text_only(self):
model_id = "google/gemma-3-1b-it"
model = Gemma3nForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
inputs = tokenizer("Write a poem about Machine Learning.", return_tensors="pt").to(torch_device)
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning'] # fmt: skip
self.assertEqual(output_text, EXPECTED_TEXTS)
# TODO: raushan FA2 generates gibberish for no reason, check later
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
def test_model_4b_flash_attn(self):
model_id = "Google/gemma-3n-E4B-it"
model = Gemma3nForConditionalGeneration.from_pretrained(
model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
).to(torch_device)
inputs = self.processor.apply_chat_template(
self.messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
).to(torch_device)
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. It looks like a very sunny and'] # fmt: skip
self.assertEqual(output_text, EXPECTED_TEXTS)
@parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)])
def test_generation_beyond_sliding_window(self, attn_implementation: str):
"""Test that we can correctly generate beyond the sliding window. This is non trivial as
we need to correctly slice the attention mask in all cases (because we use a HybridCache).
Outputs for every attention functions should be coherent and identical.
"""
model_id = "google/gemma-3-1b-it"
input_text = [
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
"A list of colors: red, blue", # This will almost all be padding tokens
]
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
).to(torch_device)
# Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)
out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
output_text = tokenizer.batch_decode(out)
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
def test_generation_beyond_sliding_window_with_generation_config(self):
"""
Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
"""
model_id = "google/gemma-3-1b-it"
attn_implementation = "sdpa"
input_text = [
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
"A list of colors: red, blue", # This will almost all be padding tokens
]
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
).to(torch_device)
# Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)
generation_config = GenerationConfig(max_new_tokens=20)
out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
output_text = tokenizer.batch_decode(out)
EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
self.assertEqual(output_text, EXPECTED_COMPLETIONS)

View File

@@ -0,0 +1,185 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
import numpy as np
from parameterized import parameterized
from transformers import GemmaTokenizerFast, SiglipImageProcessorFast, is_speech_available
from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio, require_vision
from .test_feature_extraction_gemma3n import floats_list
if is_speech_available():
from transformers.models.gemma3n import Gemma3nAudioFeatureExtractor, Gemma3nProcessor
@require_torch
@require_torchaudio
@require_vision
@require_sentencepiece
class Gemma3nProcessorTest(unittest.TestCase):
def setUp(self):
# TODO: update to google?
self.model_id = "Google/gemma-3n-E4B-it"
self.tmpdirname = tempfile.mkdtemp(suffix="gemma3n")
self.maxDiff = None
def get_tokenizer(self, **kwargs):
return GemmaTokenizerFast.from_pretrained(self.model_id, **kwargs)
def get_feature_extractor(self, **kwargs):
return Gemma3nAudioFeatureExtractor.from_pretrained(self.model_id, **kwargs)
def get_image_processor(self, **kwargs):
return SiglipImageProcessorFast.from_pretrained(self.model_id, **kwargs)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_save_load_pretrained_default(self):
# NOTE: feature_extractor and image_processor both use the same filename, preprocessor_config.json, when saved to
# disk, but the files are overwritten by processor.save_pretrained(). This test does not attempt to address
# this potential issue, and as such, does not guarantee content accuracy.
tokenizer = self.get_tokenizer()
feature_extractor = self.get_feature_extractor()
image_processor = self.get_image_processor()
processor = Gemma3nProcessor(
tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
)
processor.save_pretrained(self.tmpdirname)
processor = Gemma3nProcessor.from_pretrained(self.tmpdirname)
self.assertIsInstance(processor.tokenizer, GemmaTokenizerFast)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
self.assertIsInstance(processor.feature_extractor, Gemma3nAudioFeatureExtractor)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
def test_save_load_pretrained_additional_features(self):
tokenizer = self.get_tokenizer()
feature_extractor = self.get_feature_extractor()
image_processor = self.get_image_processor()
processor = Gemma3nProcessor(
tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
)
processor.save_pretrained(self.tmpdirname)
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS-BOS)", eos_token="(EOS-EOS)")
feature_extractor_add_kwargs = self.get_feature_extractor(dither=5.0, padding_value=1.0)
processor = Gemma3nProcessor.from_pretrained(
self.tmpdirname, bos_token="(BOS-BOS)", eos_token="(EOS-EOS)", dither=5.0, padding_value=1.0
)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
self.assertIsInstance(processor.tokenizer, GemmaTokenizerFast)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
self.assertIsInstance(processor.feature_extractor, Gemma3nAudioFeatureExtractor)
@parameterized.expand([256, 512, 768, 1024])
def test_image_processor(self, image_size: int):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
image_processor = self.get_image_processor()
processor = Gemma3nProcessor(
tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
)
raw_image = np.random.randint(0, 256, size=(image_size, image_size, 3), dtype=np.uint8)
input_image_processor = image_processor(raw_image, return_tensors="pt")
input_processor = processor(text="Describe:", images=raw_image, return_tensors="pt")
for key in input_image_processor.keys():
self.assertAlmostEqual(input_image_processor[key].sum(), input_processor[key].sum(), delta=1e-2)
if "pixel_values" in key:
# NOTE: all images should be re-scaled to 768x768
self.assertEqual(input_image_processor[key].shape, (1, 3, 768, 768))
self.assertEqual(input_processor[key].shape, (1, 3, 768, 768))
def test_audio_feature_extractor(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
image_processor = self.get_image_processor()
processor = Gemma3nProcessor(
tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
)
raw_speech = floats_list((3, 1000))
input_feat_extract = feature_extractor(raw_speech, return_tensors="pt")
input_processor = processor(text="Transcribe:", audio=raw_speech, return_tensors="pt")
for key in input_feat_extract.keys():
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
def test_tokenizer(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
image_processor = self.get_image_processor()
processor = Gemma3nProcessor(
tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
)
input_str = "This is a test string"
encoded_processor = processor(text=input_str)
encoded_tok = tokenizer(input_str)
for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key][0])
def test_tokenizer_decode(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
image_processor = self.get_image_processor()
processor = Gemma3nProcessor(
tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
)
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
decoded_processor = processor.batch_decode(predicted_ids)
decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
image_processor = self.get_image_processor()
processor = Gemma3nProcessor(
tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
)
for key in feature_extractor.model_input_names:
self.assertIn(
key,
processor.model_input_names,
)
for key in image_processor.model_input_names:
self.assertIn(
key,
processor.model_input_names,
)

View File

@@ -277,6 +277,7 @@ SPECIAL_CASES_TO_ALLOW = {
], ],
"Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"], "Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
"SmolLM3Config": ["no_rope_layer_interval"], "SmolLM3Config": ["no_rope_layer_interval"],
"Gemma3nVisionConfig": ["architecture", "do_pooling", "model_args"], # this is for use in `timm`
} }

View File

@@ -79,6 +79,7 @@ ALWAYS_OVERRIDE = ["labels"]
# docstrings instead. If formatting should be ignored for the docstring, you can put a comment # no-format on the # docstrings instead. If formatting should be ignored for the docstring, you can put a comment # no-format on the
# line before the docstring. # line before the docstring.
OBJECTS_TO_IGNORE = [ OBJECTS_TO_IGNORE = [
"Gemma3nVisionConfig",
"Llama4Processor", "Llama4Processor",
# Deprecated # Deprecated
"InputExample", "InputExample",