Gemma 3n (#39059)
* Gemma 3n * initial commit of Gemma 3n scaffold * Fixing param pass through on Gemm3p5RMSNorm * Adds Einsum layer to Gemma 3n * Updating EinsumLayer API * Undoing erroneous force push * Reverting RMSNorm to with_scale by default * Adds LAuReL to Gemma 3n * Adds AltUp to Gemma 3n * Adding Gemma3p5 overall and text config with vision and audio config placeholders (#3) * Adding gemma3p5 text configs * Adding audio config placeholders * Adding a placeholder for vision configs * Updating MobileNetVisionConfig, inheriting TimmWrapperConfig * Updating text configs * Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Removing altup configs to accept the suggested configs * Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Updating altup config * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Addressing review comments and updating text configs * Adding a config for activation sparsity * Updating configs to pass through options to super class init and adjust some name prefixes * Updating laurel and altup with corrected config values * Normalizing sub_config initializers --------- Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Updating MLP with activation sparsity (#2) * Updating DecoderBlock for Gemma 3n (#3) * Initial Gemm3nTextModel (#4) NOTE: This implementation WILL CHANGE in the coming weeks, however, changes will be strictly additive and this will remain a suitable baseline for downstream implementations to reference. 
* Adding KV Cache Sharing * Adds Einsum layer to Gemma 3n * Updating EinsumLayer API * Refactored kv cache sharing in attention * Adding KVStore for cache sharing * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update src/transformers/cache_utils.py Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Undoing erroneous force push * Reverting RMSNorm to with_scale by default * Adds LAuReL to Gemma 3n * Updating KV Cache Sharing implementation * Updating the q and k norm definitions in the attention module * Fixing name error for q,k,v RMS norm to use the right 3n module * Updating MLP with activation sparsity * Updating DecoderBlock for Gemma 3.5 * Updating kv cache sharing implementation with the use of a cache buffer and refactoring some lines of code * Isolating KV Cache logic to relevant components * Fixing logic error in Gemma3nAttention.forward * Refactoring caching contributions and fixing kv_store initialization * Simplifying Configs * Remove errant self from super init call * Bug fix in the Attention module - changing self.head_dim to config.head_dim * Bug fixes in the LaurelBlock and RMS Norm super init call * removing redundant code from a merge * Adding per_layer_inputs to TextModel * Adding preprocess embeddings with altup * Adds per-layer-to-single output and a host of TODOs * Integrating altup predict with the model workflow and other minor bug fixes * Using nn.Embedding temporarily for text model * It goes forward * Minor refactor of attention sparsity and RoPE initialization * Fixing duplicate rope_scaling param bug when loading from pretrained --------- Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> * Normalizing on altup_num_inputs config option * regenerating modeling file after 
syncing to HEAD * Use torch.std(..., unbiased=False) for activation sparsity (#8) * Refactoring to a single QVK Norm (#13) * AltUp: support scale_corrected_output (#14) * Converts einsums to nn.Linear (#7) * Converts einsums to nn.Linear * Removing unused variables * Aligning SharedKVCache with HybridCache (#11) * Alinging SharedKVStore with HybridCache * Remove KVStore. Refactor apply_rotary_pos_emb for sharing * Addressing review comments * Supporting split modality embeddings in Gemma3n (#10) * Adding the Embedder class * Update modular Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Update modular Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Update modular Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Update modular Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Update modular Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Update modular Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Addressing review comments, adding audio embedding layers, integrating embedder with the remaining architecture, adding a forward method for conditional generation * Apply suggestions from code review Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Update modular Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> * Addressing review comments, prop drilling audio and vision configs to the text config * Removing TODO's that have been addressed * Simplify Embedder init and add audio embeddings * Embeddings refactor. Adds Gemma3nAudioEmbedder and Gemma3nVisionEmbedder * Refactoring vision and audio embeddings into ConditionalGeneration model --------- Co-authored-by: Ryan Mullins <ryan@ryanmullins.org> Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Updating attention mask for Gemma 3.5 (#15) * xxx_token_index to xxx_token_id * remvoing deprecated last_cache_position * Removing references to SigLIP * Always init per-layer inputs * Using torch.finfo().min for epsilon_tensor * Gemma3nDecoderLayer inherits from Gemma3DecoderLayer. 
Remove gating lambdas * fix modular GEMMA3N_INPUTS_DOCSTRING * Gemma3nAttention inherits from Gemma3Attention * Modular inheritance fixes * CausalLM conversion script for 4B model (#16) * Add Gemma3n Audio Encoder (#6) * initial commit of Gemma 3.5 scaffold * Fixing param pass through on Gemm3nRMSNorm * Adds Einsum layer to Gemma 3.5 * Updating EinsumLayer API * Undoing erroneous force push * Reverting RMSNorm to with_scale by default * Adds LAuReL to Gemma 3n * Adds AltUp to Gemma 3n * Adding Gemma3n overall and text config with vision and audio config placeholders (#3) * Adding gemma3n text configs * Adding audio config placeholders * Adding a placeholder for vision configs * Updating MobileNetVisionConfig, inheriting TimmWrapperConfig * Updating text configs * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Removing altup configs to accept the suggested configs * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Updating altup config * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Addressing review comments and updating text configs * Adding a config for activation sparsity * Updating configs to pass through options to super class init and adjust some name prefixes * Updating laurel and altup with corrected config values * Normalizing sub_config initializers --------- Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Updating MLP with activation sparsity (#2) * Updating DecoderBlock for Gemma 3.5 (#3) * Initial Gemm3nTextModel (#4) NOTE: This implementation WILL CHANGE in the coming weeks, however, changes will be strictly additive and this will remain a suitable baseline for downstream implementations to reference. 
* Adding KV Cache Sharing * Adds Einsum layer to Gemma 3.5 * Updating EinsumLayer API * Refactored kv cache sharing in attention * Adding KVStore for cache sharing * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update src/transformers/cache_utils.py Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Undoing erroneous force push * Reverting RMSNorm to with_scale by default * Adds LAuReL to Gemma 3n * Updating KV Cache Sharing implementation * Updating the q and k norm definitions in the attention module * Fixing name error for q,k,v RMS norm to use the right Gemma 3n module * Updating MLP with activation sparsity * Updating DecoderBlock for Gemma 3.5 * Updating kv cache sharing implementation with the use of a cache buffer and refactoring some lines of code * Isolating KV Cache logic to relevant components * Fixing logic error in Gemma3nAttention.forward * Refactoring caching contributions and fixing kv_store initialization * Simplifying Configs * Remove errant self from super init call * Bug fix in the Attention module - changing self.head_dim to config.head_dim * Bug fixes in the LaurelBlock and RMS Norm super init call * removing redundant code from a merge * Adding per_layer_inputs to TextModel * Adding preprocess embeddings with altup * Adds per-layer-to-single output and a host of TODOs * Integrating altup predict with the model workflow and other minor bug fixes * Using nn.Embedding temporarily for text model * It goes forward * Minor refactor of attention sparsity and RoPE initialization * Fixing duplicate rope_scaling param bug when loading from pretrained --------- Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> * Normalizing on altup_num_inputs config option * Adding audio encoder config 
* Adds high-level components for Audio Encoder * Implement uniform reducer for Audio Encoder * Adding placeholders for Conformer components in Audio Encoder * Adding placeholders for SubSampleConvProjection components in Audio Encoder * Adding SequenceLayer component placeholders * Implementing Gemma3nAudioEncoder with nn.Sequential * Implementing Gemma3nAudioSubSampleConvProjection with nn.Sequential * Implementing Conformer model with SequenceLayers * Use OrderedDict in nn.Sequential initializers * Implements sl.Residual in Torch with nn.Sequential and OrderedDict * Adopting a base SequenceLayer class with default forward() method * Implementing sl.GatedLinearUnit in Torch * Implementing sl.Swish in Torch * Implementing sl.ReLU in Torch * Implementing sl.Scale in Torch * Removing sl.Dropout after tree-shaking * Implementing sl.RMSNorm in Torch with fake shape * Implementing sl.GroupNorm in Torch * Implementing sl.Conv2d in Torch * Implementing sl.Dense in Torch * Removing sl.Delay layers, which act as pass-throughs * Connecting shapes to configs in initializers * Removing sl.Emit * Implementing sl.ExpandDims in Torch * Adding sl.GradientClipping to Torch * Implementing sl.DenseShaped in Torch * Implementing sl.LDPA in Torch * Removing unused sl.CombinedQKVProj class * Fixing erroneous type hint * Implemnenting sl.DepthwiseConv1D in Torch * Implementing sl.MaskInvalid in Torch * Fixes for initialization * Fixes for saving weights * Removing einsums per feedback from HF staff * Removing Sequence Layers idioms from audio encoder * Fixes for reviewer comments * CausalLM conversion script for 4B model * inv_timescales to non-persistent buffer * Addressing audio encoder Attention feedback * Addressing Gemma3nAudioSSCPConvBlock feedback * Addressing Gemma3nAudioConformerAttention feedback * Addressing padding feedback * Weights conversion loads audio state dict * Always use vision_config so saving works * Token id updates for configs * Stubs for interleaving audio embs 
* Addressing reviewer feedback --------- Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> * Fixing cache access error * Removing duplicate code from a bad merge * Gemma 3n Text + Vision Part 1 (#17) * testing utilities for numerics comparisons * Corrected einsum to nn.Linear weights conversion * Inherit scaled word embs from Gemma3 not Bart * Fixing transposes for collapsed linears * More transpose fixes * numpy api fix * RMSNorm: Explicit kwargs, scale_shift=0.0 when with_scale=True * Force AltUp to float32 * Updating debugging script for AudioEncoder debugging * Support divide_weight_by_sqrt_fan_in from JAX for per-layer inputs * Correcting attention einsum conversions * RMSNorm in type of x * Fixing douplicate laurel norm/gating * KV sharing using the right previous indices * Refactor kv shared index computation. Correct frac_shared_layers * Use num_shared_layers instead of inferring from a fraction * fixing a bug for logging * Fix shared data_ptrs in altup inits * rope: adjust proj -> norm -> rope to preserve computation (#20) * rope: adjust proj -> norm -> rope to preserve computation * Removing some breaking language model fluff in ConditionalGeneration * Consolidate query_states transforms --------- Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Vectorize the loops in AltUp (#19) * Vectorize the loops in AltUp * fix typo * Expanding to support batched inputs * remove extra debug script * Fix AltUp.forward --------- Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Add 'scale_shift=0.0, with_scale=True' to the final norm in TextModel * Convert norm to 1/sqrt (#21) * Convert norm to 1/sqrt * Scale shift change per Phil's rec * Adding default activation sparsity * Fixing 2B config in weights conversion script * Fixing RMSNorm parameters - adding scale_shift and with_scale * 
Correcting query pre-attention scaling * Adding query_rescale_scalar to text config * Adding layer_idx to MLP * Permafix for input_layernorm * Use 1/sqrt instead of rsqrt in DecoderLayer * Fix o_proj conversion * Conversion script update for vision encoder * Removing logging for debugging timm model * Fixing bugs in Gemma3nForConditionalGeneration for text generation * Generating the modeling_gemma3n.py file * Removing the addition of an erroneous line in the modeling file * Adding gemma3n text model to modeling_auto * Bugfix: Updating the interleaving of inputs_embeds and vision_embeds * Updating the modeling file with the latest bugfix changes * Updating models/auto for Gemma 3n * using AutoTokenizer in forward test * Adding processing_gemma3n.py * Gemma 3n configured for AutoModel. Conversion script updated. * Removing errant merge artifacts --------- Co-authored-by: Mayank Chaturvedi <imayank@google.com> Co-authored-by: Douglas Reid <douglas-reid@users.noreply.github.com> Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com> Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> * Removing errant debugging statements from Gemma 3 * Gemma3n audio model (#18) * testing utilities for numerics comparisons * Implement CumulativeGroupNorm and add to SubSampleConvProjection and SSCPConvBlock * Add audio version of forward script based on RyanMullins' implementation * Updating to match encoder tests. 
WIP: config question needs resolving * Updates to audio classes to enable end-to-end running * Removing vestigial classes, cleaning up print statements * Adding SiLU / Swish to audio conformer feed forward block * Shifted Gemma3p5Audio naming prefix to Gemma3NanoAudio * Adding outputs to audio test * Fixes to padding in SSCP and 1D convolution, align RMS Norm with wider model * Update forward test to load from local weights * Update conversion to process / output audio layers * Update __all__ to export audio encoder * AutoModel registration for Gemma 3n Audio * Use AutoModel for ConditionalGeneration.audio_tower * Fixing input_proj_linear transpose * Fixing Gemma3NanoAudioConformerAttention.post conversion * Fixing Gemma3NanoAudioSSCPConvBlock.conv weights conversion * Correcting indentation issue on Gemma3p5RMSNorm --------- Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Text + Vision Part 2 (#23) * Updates for ConditionalGeneration.get_image_features * Adding a WIP draft of image_processing_gemma3p5.py * Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> * Modular conversion after github suggested change * Text + image gives good results * Fixing image size preset * Updating configs for the 2B variant in the conversion script * Using final generation config in conversion script --------- Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> * Audio Integration (#12) * initial commit of Gemma 3n scaffold * Fixing param pass through on Gemm3nRMSNorm * Adds Einsum layer to Gemma 3n * Updating EinsumLayer API * Undoing erroneous force push * Reverting RMSNorm to with_scale by default * Adds LAuReL to Gemma 3n * Adds AltUp to Gemma 3n * Adding Gemma 3n overall and text config with vision and audio config placeholders (#3) * Adding Gemma 3n text configs * Adding audio config 
placeholders * Adding a placeholder for vision configs * Updating MobileNetVisionConfig, inheriting TimmWrapperConfig * Updating text configs * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Removing altup configs to accept the suggested configs * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Updating altup config * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Addressing review comments and updating text configs * Adding a config for activation sparsity * Updating configs to pass through options to super class init and adjust some name prefixes * Updating laurel and altup with corrected config values * Normalizing sub_config initializers --------- Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Updating MLP with activation sparsity (#2) * Updating DecoderBlock for Gemma 3n (#3) * Initial Gemma3nTextModel (#4) NOTE: This implementation WILL CHANGE in the coming weeks, however, changes will be strictly additive and this will remain a suitable baseline for downstream implementations to reference. 
* Adding KV Cache Sharing * Adds Einsum layer to Gemma 3n * Updating EinsumLayer API * Refactored kv cache sharing in attention * Adding KVStore for cache sharing * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update modular Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Update src/transformers/cache_utils.py Co-authored-by: Ryan Mullins <ryanmullins@google.com> * Undoing erroneous force push * Reverting RMSNorm to with_scale by default * Adds LAuReL to Gemma 3n * Updating KV Cache Sharing implementation * Updating the q and k norm definitions in the attention module * Fixing name error for q,k,v RMS norm to use the right 3n module * Updating MLP with activation sparsity * Updating DecoderBlock for Gemma 3n * Updating kv cache sharing implementation with the use of a cache buffer and refactoring some lines of code * Isolating KV Cache logic to relevant components * Fixing logic error in Gemma3nAttention.forward * Refactoring caching contributions and fixing kv_store initialization * Simplifying Configs * Remove errant self from super init call * Bug fix in the Attention module - changing self.head_dim to config.head_dim * Bug fixes in the LaurelBlock and RMS Norm super init call * removing redundant code from a merge * Adding per_layer_inputs to TextModel * Adding preprocess embeddings with altup * Adds per-layer-to-single output and a host of TODOs * Integrating altup predict with the model workflow and other minor bug fixes * Using nn.Embedding temporarily for text model * It goes forward * Minor refactor of attention sparsity and RoPE initialization * Fixing duplicate rope_scaling param bug when loading from pretrained --------- Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> * Normalizing on altup_num_inputs config option * Adding audio encoder config * Adds 
high-level components for Audio Encoder * Implement uniform reducer for Audio Encoder * Adding placeholders for Conformer components in Audio Encoder * Adding placeholders for SubSampleConvProjection components in Audio Encoder * Adding SequenceLayer component placeholders * Implementing Gemma3nAudioEncoder with nn.Sequential * Implementing Gemma3nAudioSubSampleConvProjection with nn.Sequential * Implementing Conformer model with SequenceLayers * Use OrderedDict in nn.Sequential initializers * Implements sl.Residual in Torch with nn.Sequential and OrderedDict * Adopting a base SequenceLayer class with default forward() method * Implementing sl.GatedLinearUnit in Torch * Implementing sl.Swish in Torch * Implementing sl.ReLU in Torch * Implementing sl.Scale in Torch * Removing sl.Dropout after tree-shaking * Implementing sl.RMSNorm in Torch with fake shape * Implementing sl.GroupNorm in Torch * Implementing sl.Conv2d in Torch * Implementing sl.Dense in Torch * Removing sl.Delay layers, which act as pass-throughs * Connecting shapes to configs in initializers * Removing sl.Emit * Implementing sl.ExpandDims in Torch * Adding sl.GradientClipping to Torch * Implementing sl.DenseShaped in Torch * Implementing sl.LDPA in Torch * Removing unused sl.CombinedQKVProj class * Fixing erroneous type hint * Implemnenting sl.DepthwiseConv1D in Torch * Implementing sl.MaskInvalid in Torch * Fixes for initialization * Fixes for saving weights * Removing einsums per feedback from HF staff * Removing Sequence Layers idioms from audio encoder * Fixes for reviewer comments * Converting sl.Frontend to FeatureExtractor * Updates for ConditionalGeneration.get_image_features * Adding a WIP draft of image_processing_gemma3n.py * Update modular Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> * Modular conversion after github suggested change * Text + image gives good results * Fixing image size preset * Draft of audio data in chat template * Removing 
image processing. Using SigLIP instead. * Audio input going end-to-end * Fixing dtype issues in audio encoder * x-lib formatting consistency * Adding example data * Save preprocessor_config.json from conversion script * Instrumentaiton for debugging * Additional instrumentation for preprocessing debugging * Updates to preprocessor, padding; produces correct end-to-end results on sample * Tackling configuraiton TODOs * Start of feature extractor refatcor * Adds Numpy version of USM extractor, removes Torch version and dependencies * Fixing AltUp.correct coef permute * Supporting batches of single audio segment inputs * Docstrings updates for config * In-lining audio feature extraction * Adjustments to conversion script and smoke test script --------- Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> Co-authored-by: pculliton <phillipculliton@gmail.com> * Gemma 3n renaming * Removing test data and utilities * Renaming test files * Gemma 3n refactor * Fix tokenizer config in conversion script * Address reviewer feedback * FeatureExtractor returns float32 by default * Adding basic tests for audio, and input name for audio encoder * Audio integration test, updates to model_id for other integration tests * Use scales for q and k norms (#26) * Update audio integration test to use HF dataset * Reviewer feedback * Expand embedding table to full vocab size in weights conversion * Mix-n-match MatFormers for Gemma 3n (#25) * Remove in-place operations (#30) * chore: removing inplace ops * remove [tensor] * n pattern * chore: reviewer feedback in AudioEncoder and AltUp * More grad clipping * Dynamo compatibility * fix: cache slicing error * chore: simplify shared kv cache slicing * chore: vision encoder rename in timm * fix: image processor do_normalize=False * fixup: style * chore: model_doc * fix: docs for code quality * chore: repo consistency * fix: RMSNorm in float as in prior 
Gemmas * fix: per_layer_inputs = None * chore: Gemma3nForCausalLM from Gemma3nForConditionalGeneration checkpoint * chore: repo consistency * Add initial unit tests for Gemma3nAudioFeatureExtractor (#27) * Add initial unit tests for Gemma3nAudioFeatureExtractor * Add basic unit tests for Gemma3nProcessor (#28) Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> * parameterize tests --------- Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> * chore: code style * fix: test cases * style and consistency * fix config in the test to be coherent with layer cache sharing * fix hidden states in tests and code * inits and mappings * fix modality prefixes * test order and prefixes * fix test exception * fix class order and reduce model size for faster tests * restore _checkpoint_conversion_mapping to load Caual from Conditional * fix config mapping! * fix: reviewer feedback --------- Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> Co-authored-by: raushan <raushan@huggingface.co> Co-authored-by: Mayank Chaturvedi <imayank@google.com> Co-authored-by: Douglas Reid <douglas-reid@users.noreply.github.com> Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com> Co-authored-by: pculliton <phillipculliton@gmail.com> Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com> Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> * fix import test * add model args * auto_docstring * replace test path * consistency * skip tests for now * fix docstring for doc builder * skip unused attr --------- Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com> Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com> Co-authored-by: raushan <raushan@huggingface.co> Co-authored-by: Mayank Chaturvedi <imayank@google.com> 
Co-authored-by: Douglas Reid <douglas-reid@users.noreply.github.com> Co-authored-by: Douglas Reid <21148125+douglas-reid@users.noreply.github.com> Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com> Co-authored-by: pculliton <phillipculliton@gmail.com> Co-authored-by: Aritra Roy Gosthipaty <aritra.born2fly@gmail.com> Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> Co-authored-by: Arthur <arthur.zucker@gmail.com>
@@ -959,6 +959,8 @@
   title: FLAVA
 - local: model_doc/gemma3
   title: Gemma3
+- local: model_doc/gemma3n
+  title: Gemma3n
 - local: model_doc/git
   title: GIT
 - local: model_doc/glm4v
204
docs/source/en/model_doc/gemma3n.md
Normal file
@@ -0,0 +1,204 @@

<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>

# Gemma3n

## Overview

Gemma 3n is a multimodal model with pretrained and instruction-tuned variants, available in E4B and E2B sizes. While
large portions of the language model architecture are shared with prior Gemma releases, this model adds several new
components, including [Alternating Updates][altup] (AltUp), [Learned Augmented Residual Layer][laurel] (LAuReL),
[MatFormer][matformer], Per-Layer Embeddings (PLE), activation sparsity, and KV cache sharing. The language model uses
an attention pattern similar to [Gemma 3](./gemma3.md): four local sliding-window self-attention layers alternate with
each global self-attention layer, and the maximum context length is 32k tokens. Gemma 3n introduces
[MobileNet v5][mobilenetv5] as the vision encoder, using a default resolution of 768x768 pixels, and adds a
[Universal Speech Model][usm] (USM) as the audio encoder.

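
The four-to-one ratio of local to global attention layers described above can be illustrated with a small sketch. The
helper name `attention_layer_types` and the labels `"local"` and `"global"` are hypothetical and are not part of the
Transformers API. The sketch also assumes the global layer closes each group of five, which the text above does not
state; the real layer order is whatever the model configuration defines.

```python
def attention_layer_types(num_layers: int, local_per_global: int = 4) -> list[str]:
    """Label each decoder layer as local (sliding window) or global attention.

    Hypothetical helper: assumes every run of `local_per_global` local layers
    is followed by exactly one global layer.
    """
    period = local_per_global + 1
    return [
        "global" if (index + 1) % period == 0 else "local"
        for index in range(num_layers)
    ]


# Ten layers give two full groups of four local layers plus one global layer.
pattern = attention_layer_types(num_layers=10)
print(pattern)
```

Under these assumptions, only one layer in five attends over the full context; the remaining layers see a sliding
window.
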
The instruction-tuned variant was post-trained with knowledge distillation and reinforcement learning.

You can find all the original Gemma 3n checkpoints under the [Gemma 3n][gemma3n-collection] release.

> [!TIP]
> Click on the Gemma 3n models in the right sidebar for more examples of how to apply Gemma to different vision, audio,
> and language tasks.

The examples below demonstrate how to generate text from an image with [`Pipeline`] or the [`AutoModel`] class, and
from a text prompt with the `transformers` CLI.

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
import torch
from transformers import pipeline

# Use a distinct variable name so the imported `pipeline` factory is not shadowed.
pipe = pipeline(
    task="image-text-to-text",
    model="google/gemma-3n-e4b",
    device=0,
    torch_dtype=torch.bfloat16
)
pipe(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
    text="<start_of_image> What is shown in this image?"
)
```

</hfoption>
<hfoption id="AutoModel">
|
||||||
|
|
||||||
|
```py
import torch
from transformers import AutoProcessor, Gemma3nForConditionalGeneration

model = Gemma3nForConditionalGeneration.from_pretrained(
    "google/gemma-3n-e4b-it",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)
processor = AutoProcessor.from_pretrained(
    "google/gemma-3n-e4b-it",
    padding_side="left"
)

messages = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are a helpful assistant."}
        ]
    },
    {
        "role": "user", "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
            {"type": "text", "text": "What is shown in this image?"},
        ]
    },
]
# Move the inputs to wherever `device_map="auto"` placed the model.
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True,
).to(model.device)

output = model.generate(**inputs, max_new_tokens=50, cache_implementation="static")
print(processor.decode(output[0], skip_special_tokens=True))
```

</hfoption>
<hfoption id="transformers CLI">

```bash
echo -e "Plants create energy through a process known as" | transformers run --task text-generation --model google/gemma-3n-e2b --device 0
```

</hfoption>
</hfoptions>
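
The `AutoModel` example decodes the whole output sequence, which includes the prompt tokens. A minimal sketch of keeping only the newly generated tokens is shown here with plain lists standing in for tensors; `strip_prompt` and the token ids are hypothetical and only illustrate the slicing.

```python
def strip_prompt(sequence, prompt_length):
    """Return only the tokens that were generated after the prompt."""
    if prompt_length < 0 or prompt_length > len(sequence):
        raise ValueError("prompt_length must be between 0 and len(sequence)")
    return sequence[prompt_length:]


# Made-up token ids: the first four stand in for the prompt.
full_sequence = [2, 105, 2364, 107, 9259, 236764, 1902]
new_tokens = strip_prompt(full_sequence, prompt_length=4)
print(new_tokens)
```

With real tensors the same idea is usually written as `output[0][inputs["input_ids"].shape[-1]:]` before calling `processor.decode`.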

## Notes

- Use [`Gemma3nForConditionalGeneration`] for image-audio-and-text, image-and-text, image-and-audio, audio-and-text,
  image-only and audio-only inputs.
- Gemma 3n supports multiple images per input, but make sure the images are correctly batched before passing them to
  the processor. Each batch should be a list of one or more images.

    ```py
    url_cow = "https://media.istockphoto.com/id/1192867753/photo/cow-in-berchida-beach-siniscola.jpg?s=612x612&w=0&k=20&c=v0hjjniwsMNfJSuKWZuIn8pssmD5h5bSN1peBd1CmH4="
    url_cat = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"

    messages = [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": "You are a helpful assistant."}
            ]
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "url": url_cow},
                {"type": "image", "url": url_cat},
                {"type": "text", "text": "Which image is cuter?"},
            ]
        },
    ]
    ```
- Text passed to the processor should have a `<image_soft_token>` token wherever an image should be inserted.
- Gemma 3n accepts at most one target audio clip per input, though multiple audio clips can be provided in few-shot
  prompts, for example.
- Text passed to the processor should have an `<audio_soft_token>` token wherever an audio clip should be inserted.
- The processor has its own [`~ProcessorMixin.apply_chat_template`] method to convert chat messages to model inputs.
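
As a rough illustration of the notes above, the sketch below assembles a single user turn that mixes images, one audio clip, and text. `build_user_message` is a hypothetical helper, not part of Transformers. The `{"type": "audio", "audio": ...}` entry shape and the file name are assumptions, while the image and text entries follow the examples in this document.

```python
def build_user_message(text, image_urls=(), audio_path=None):
    """Build one chat turn with optional images and at most one audio clip."""
    content = [{"type": "image", "url": url} for url in image_urls]
    if audio_path is not None:
        # Assumed entry shape for audio; check the processor docs before relying on it.
        content.append({"type": "audio", "audio": audio_path})
    content.append({"type": "text", "text": text})
    return {"role": "user", "content": content}


message = build_user_message(
    "Describe what you see and hear.",
    image_urls=["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"],
    audio_path="clip.wav",  # hypothetical local file
)
print([part["type"] for part in message["content"]])
```

A list of such messages can then be passed to `processor.apply_chat_template`, as in the `AutoModel` example.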

## Gemma3nAudioFeatureExtractor

[[autodoc]] Gemma3nAudioFeatureExtractor

## Gemma3nProcessor

[[autodoc]] Gemma3nProcessor

## Gemma3nTextConfig

[[autodoc]] Gemma3nTextConfig

## Gemma3nVisionConfig

[[autodoc]] Gemma3nVisionConfig

## Gemma3nAudioConfig

[[autodoc]] Gemma3nAudioConfig

## Gemma3nConfig

[[autodoc]] Gemma3nConfig

## Gemma3nTextModel

[[autodoc]] Gemma3nTextModel
    - forward

## Gemma3nModel

[[autodoc]] Gemma3nModel
    - forward

## Gemma3nForCausalLM

[[autodoc]] Gemma3nForCausalLM
    - forward

## Gemma3nForConditionalGeneration

[[autodoc]] Gemma3nForConditionalGeneration
    - forward

[altup]: https://proceedings.neurips.cc/paper_files/paper/2023/hash/f2059277ac6ce66e7e5543001afa8bb5-Abstract-Conference.html
[attention-mask-viz]: https://github.com/huggingface/transformers/blob/beb9b5b02246b9b7ee81ddf938f93f44cfeaad19/src/transformers/utils/attention_visualizer.py#L139
[gemma3n-collection]: https://huggingface.co/collections/google/gemma-3n
[laurel]: https://arxiv.org/abs/2411.07501
[matformer]: https://arxiv.org/abs/2310.07707
[usm]: https://arxiv.org/abs/2303.01037
@@ -140,6 +140,10 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
         ("gemma2", "Gemma2Config"),
         ("gemma3", "Gemma3Config"),
         ("gemma3_text", "Gemma3TextConfig"),
+        ("gemma3n", "Gemma3nConfig"),
+        ("gemma3n_audio", "Gemma3nAudioConfig"),
+        ("gemma3n_text", "Gemma3nTextConfig"),
+        ("gemma3n_vision", "Gemma3nVisionConfig"),
         ("git", "GitConfig"),
         ("glm", "GlmConfig"),
         ("glm4", "Glm4Config"),
@@ -518,6 +522,10 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
         ("gemma2", "Gemma2"),
         ("gemma3", "Gemma3ForConditionalGeneration"),
         ("gemma3_text", "Gemma3ForCausalLM"),
+        ("gemma3n", "Gemma3nForConditionalGeneration"),
+        ("gemma3n_audio", "Gemma3nAudioEncoder"),
+        ("gemma3n_text", "Gemma3nForCausalLM"),
+        ("gemma3n_vision", "TimmWrapperModel"),
         ("git", "GIT"),
         ("glm", "GLM"),
         ("glm4", "GLM4"),
@@ -839,6 +847,9 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
         ("clip_text_model", "clip"),
         ("aria_text", "aria"),
         ("gemma3_text", "gemma3"),
+        ("gemma3n_audio", "gemma3n"),
+        ("gemma3n_text", "gemma3n"),
+        ("gemma3n_vision", "gemma3n"),
         ("glm4v_text", "glm4v"),
         ("idefics3_vision", "idefics3"),
         ("siglip_vision_model", "siglip"),
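
The three `gemma3n_*` entries above exist because the audio, text, and vision sub-configs all live in the single `gemma3n` module directory. A simplified sketch of that kind of lookup is shown here: `module_name_for` is a hypothetical stand-in, not the library's resolver, and the fallback rule (replacing `-` with `_`) is an assumption made for illustration.

```python
# Subset of the special-case table from the hunk above.
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = {
    "gemma3_text": "gemma3",
    "gemma3n_audio": "gemma3n",
    "gemma3n_text": "gemma3n",
    "gemma3n_vision": "gemma3n",
}


def module_name_for(model_type):
    """Map a config `model_type` to the module directory that defines it."""
    if model_type in SPECIAL_MODEL_TYPE_TO_MODULE_NAME:
        return SPECIAL_MODEL_TYPE_TO_MODULE_NAME[model_type]
    # Assumed fallback: module names use underscores where model types use dashes.
    return model_type.replace("-", "_")


for model_type in ("gemma3n", "gemma3n_text", "gemma3n_audio"):
    print(model_type, "->", module_name_for(model_type))
```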
@@ -61,6 +61,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
         ("dpt", "DPTFeatureExtractor"),
         ("encodec", "EncodecFeatureExtractor"),
         ("flava", "FlavaFeatureExtractor"),
+        ("gemma3n", "Gemma3nAudioFeatureExtractor"),
         ("glpn", "GLPNFeatureExtractor"),
         ("granite_speech", "GraniteSpeechFeatureExtractor"),
         ("groupvit", "CLIPFeatureExtractor"),
@@ -88,6 +88,7 @@ else:
             ("focalnet", ("BitImageProcessor", "BitImageProcessorFast")),
             ("fuyu", ("FuyuImageProcessor",)),
             ("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
+            ("gemma3n", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
             ("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
             ("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")),
             ("glpn", ("GLPNImageProcessor",)),
@@ -132,6 +132,10 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("gemma2", "Gemma2Model"),
         ("gemma3", "Gemma3Model"),
         ("gemma3_text", "Gemma3TextModel"),
+        ("gemma3n", "Gemma3nModel"),
+        ("gemma3n_audio", "Gemma3nAudioEncoder"),
+        ("gemma3n_text", "Gemma3nTextModel"),
+        ("gemma3n_vision", "TimmWrapperModel"),
         ("git", "GitModel"),
         ("glm", "GlmModel"),
         ("glm4", "Glm4Model"),
@@ -583,6 +587,8 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
         ("gemma2", "Gemma2ForCausalLM"),
         ("gemma3", "Gemma3ForConditionalGeneration"),
         ("gemma3_text", "Gemma3ForCausalLM"),
+        ("gemma3n", "Gemma3nForConditionalGeneration"),
+        ("gemma3n_text", "Gemma3nForCausalLM"),
         ("git", "GitForCausalLM"),
         ("glm", "GlmForCausalLM"),
         ("glm4", "Glm4ForCausalLM"),
@@ -906,6 +912,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
         ("emu3", "Emu3ForConditionalGeneration"),
         ("fuyu", "FuyuForCausalLM"),
         ("gemma3", "Gemma3ForConditionalGeneration"),
+        ("gemma3n", "Gemma3nForConditionalGeneration"),
         ("git", "GitForCausalLM"),
         ("glm4v", "Glm4vForConditionalGeneration"),
         ("got_ocr2", "GotOcr2ForConditionalGeneration"),
@@ -66,6 +66,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
         ("flava", "FlavaProcessor"),
         ("fuyu", "FuyuProcessor"),
         ("gemma3", "Gemma3Processor"),
+        ("gemma3n", "Gemma3nProcessor"),
         ("git", "GitProcessor"),
         ("glm4v", "Glm4vProcessor"),
         ("got_ocr2", "GotOcr2Processor"),
@@ -236,6 +236,20 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
                     "GemmaTokenizerFast" if is_tokenizers_available() else None,
                 ),
             ),
+            (
+                "gemma3n",
+                (
+                    "GemmaTokenizer" if is_sentencepiece_available() else None,
+                    "GemmaTokenizerFast" if is_tokenizers_available() else None,
+                ),
+            ),
+            (
+                "gemma3n_text",
+                (
+                    "GemmaTokenizer" if is_sentencepiece_available() else None,
+                    "GemmaTokenizerFast" if is_tokenizers_available() else None,
+                ),
+            ),
             ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
             ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
             ("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
29
src/transformers/models/gemma3n/__init__.py
Normal file
@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_gemma3n import *
    from .feature_extraction_gemma3n import *
    from .modeling_gemma3n import *
    from .processing_gemma3n import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
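
The `__init__.py` above replaces the package in `sys.modules` with a lazy module so that heavy submodules are imported only when one of their names is first used. The sketch below illustrates the general idea with the standard library only. `TinyLazyModule` is a hypothetical toy, not the `_LazyModule` implementation, and the attribute-to-module table is made up for the demo.

```python
import importlib
import types


class TinyLazyModule(types.ModuleType):
    """Toy module that imports the owner of an attribute on first access."""

    def __init__(self, name, attr_to_module):
        super().__init__(name)
        self._attr_to_module = dict(attr_to_module)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails.
        if attr.startswith("_") or attr not in self._attr_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        module = importlib.import_module(self._attr_to_module[attr])
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache so later lookups skip __getattr__
        return value


lazy = TinyLazyModule("demo", {"sqrt": "math", "OrderedDict": "collections"})
print(lazy.sqrt(9.0))
```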
680
src/transformers/models/gemma3n/configuration_gemma3n.py
Normal file
@@ -0,0 +1,680 @@
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
#           This file was automatically generated from src/transformers/models/gemma3n/modular_gemma3n.py.
#               Do NOT edit this file manually as any edits will be overwritten by the generation of
#             the file from the modular. If any change should be done, please apply the change to the
#                          modular_gemma3n.py file directly. One of our CI enforces this.
#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Any, Optional, Union

from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...modeling_rope_utils import rope_config_validation
from ...utils import is_timm_available, logging, requires_backends


if is_timm_available():
    from timm.data import ImageNetInfo, infer_imagenet_subset


logger = logging.get_logger(__name__)


class Gemma3nTextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Gemma3nTextModel`]. It is used to instantiate a
    Gemma3nTextModel model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.
    [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read
    the documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 262400):
            Vocabulary size of the Gemma3nText model. Defines the number of different tokens that can be represented by
            the `input_ids` passed when calling [`Gemma3nTextModel`].
        vocab_size_per_layer_input (`int`, *optional*, defaults to 262144):
            Vocabulary size of the per-layer text embeddings that augment the standard embeddings.
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        hidden_size_per_layer_input (`int`, *optional*, defaults to 256):
            Dimension of the hidden representations for per-layer embeddings.
        intermediate_size (`int` or `Sequence[int]`, *optional*, defaults to 16384):
            Dimension of the MLP representations. MatFormer configurations may wish to provide a sequence of integers
            to account for variable intermediate_size values across layers. In such cases,
            `len(intermediate_size) == num_hidden_layers`.
        num_hidden_layers (`int`, *optional*, defaults to 35):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 2):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by mean-pooling all the original heads within that group. For more details, check out this
            [paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, will default to `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the decoder. Will default to
            `"gelu_pytorch_tanh"` if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"`
            activation function.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        rope_theta (`float`, *optional*, defaults to 1000000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings used in global attention.
            NOTE: if you apply a new rope type and you expect the model to work on longer `max_position_embeddings`, we
            recommend you update this value accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
        rope_local_base_freq (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings for local attention.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        sliding_window (`int`, *optional*, defaults to 512):
            This is the size of the sliding window used by local attention layers.
        layer_types (`Sequence[str]`, *optional*):
            A sequence of strings defining the attention type for each layer as either "sliding_attention" or
            "full_attention". If not provided, `layer_types` will be inferred from `num_hidden_layers` using a pattern
            of four "sliding_attention" layers followed by one "full_attention". The last layer in the model should
            always be a "full_attention" layer.
        final_logit_softcapping (`float`, *optional*, defaults to 30.0):
            Scaling factor when applying tanh softcapping on the logits.
        altup_active_idx (`int`, *optional*, defaults to 0):
            The index of the prediction from which AltUp will compute additional predictions or corrections.
        altup_coef_clip (`float`, *optional*, defaults to 120.0):
            The maximum amplitude of an AltUp prediction or correction coefficient weight.
        altup_correct_scale (`bool`, *optional*, defaults to `True`):
            If True, apply the `AltUp.correct_output_scale` to the corrected prediction at `altup_active_idx`.
        altup_num_inputs (`int`, *optional*, defaults to 4):
            The number of predictions that AltUp should make given the input sequence.
        num_kv_shared_layers (`int`, *optional*, defaults to 15):
            The number of layers that share KV cache values. During the forward pass, the last `num_kv_shared_layers`
            layers in the model "share" the KV values in that each local and global layer in this range uses the KV
            cache values computed for the last local or global layer, respectively, before entering this range. The
            value of `num_kv_shared_layers` should be a multiple of `sliding_window_pattern`.
        laurel_rank (`int`, *optional*, defaults to 64):
            The intermediate size for the linear projections in the Learned Augmented Residual Layer.
        activation_sparsity_pattern (`Sequence[float]`, *optional*, defaults to `(0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)`):
            The sparsity factor used to extract the top-k activations for a given layer. The provided Sequence must
            explicitly provide a sparsity value for each layer in the model.

    ```python
    >>> from transformers import Gemma3nTextModel, Gemma3nTextConfig

    >>> # Initializing a Gemma3nText gemma3n_text-E4B style configuration
    >>> configuration = Gemma3nTextConfig()

    >>> # Initializing a model from the gemma3n_text-E4B style configuration
    >>> model = Gemma3nTextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "gemma3n_text"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size: int = 262_400,
        vocab_size_per_layer_input: int = 262_144,
        hidden_size: int = 2048,
        hidden_size_per_layer_input: int = 256,
        intermediate_size: Union[int, Sequence[int]] = 16_384,
        num_hidden_layers: int = 35,
        num_attention_heads: int = 8,
        num_key_value_heads: int = 2,
        head_dim: int = 256,
        hidden_activation: str = "gelu_pytorch_tanh",
        max_position_embeddings: int = 32_768,
        initializer_range: float = 0.02,
        rms_norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: int = 0,
        eos_token_id: int = 1,
        bos_token_id: int = 2,
        rope_theta: float = 1_000_000.0,
        rope_scaling: Optional[dict[str, Any]] = None,
        rope_local_base_freq: float = 10_000.0,
        attention_bias: bool = False,
        attention_dropout: float = 0.0,
        sliding_window: int = 512,
        layer_types: Optional[Sequence[str]] = None,
        final_logit_softcapping: float = 30.0,
        altup_active_idx: int = 0,
        altup_coef_clip: float = 120.0,
        altup_correct_scale: bool = True,
        altup_num_inputs: int = 4,
        num_kv_shared_layers: int = 15,
        laurel_rank: int = 64,
        activation_sparsity_pattern: Optional[Union[float, Sequence[float]]] = (0.95,) * 10 + (0.0,) * 25,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        if isinstance(intermediate_size, Sequence) and (intsize_len := len(intermediate_size)) != num_hidden_layers:
            raise ValueError(
                "intermediate_size must have an explicit intermediate size for every layer or one for all layers. "
                f"Expected {num_hidden_layers} values but got {intsize_len}."
            )
        elif not isinstance(intermediate_size, Sequence):
            # A single value applies to every layer.
            intermediate_size = [intermediate_size] * num_hidden_layers

        self.vocab_size = vocab_size
        self.vocab_size_per_layer_input = vocab_size_per_layer_input
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_activation = hidden_activation
        self.sliding_window = sliding_window
        self.final_logit_softcapping = final_logit_softcapping
        self.layer_types = layer_types

        self.rope_local_base_freq = rope_local_base_freq
        self.rope_scaling = rope_scaling
        rope_config_validation(self)

        if layer_types is None:
            # Four sliding-attention layers followed by one full-attention layer, as documented for `layer_types`,
            # so that every fifth layer (and the last of the default 35) uses full attention.
            self.layer_types = [
                "full_attention" if (i + 1) % 5 == 0 else "sliding_attention" for i in range(self.num_hidden_layers)
            ]
        else:
            self.layer_types = layer_types

        layer_type_validation(self.layer_types)

        self.hidden_size_per_layer_input = hidden_size_per_layer_input
        self.num_kv_shared_layers = num_kv_shared_layers

        self.altup_active_idx = altup_active_idx
        self.altup_coef_clip = altup_coef_clip
        self.altup_correct_scale = altup_correct_scale
        self.altup_num_inputs = altup_num_inputs

        self.laurel_rank = laurel_rank

        if activation_sparsity_pattern is None:
            activation_sparsity_pattern = [0.0] * num_hidden_layers

        if (len_asp := len(activation_sparsity_pattern)) != num_hidden_layers:
            raise ValueError(
                "activation_sparsity_pattern must have an explicit activation sparsity value for every layer. "
                f"Expected {num_hidden_layers} values but got {len_asp}."
            )
        self.activation_sparsity_pattern = activation_sparsity_pattern
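
The per-layer normalization done in `Gemma3nTextConfig.__init__` can be illustrated in isolation. The sketch below is a standalone approximation under stated assumptions, not the library code. `default_layer_types` and `expand_per_layer` are hypothetical helpers, and the layer pattern follows the `layer_types` docstring (four "sliding_attention" layers followed by one "full_attention").

```python
from collections.abc import Sequence


def default_layer_types(num_hidden_layers, period=5):
    """Every `period`-th layer is full attention; the rest are sliding."""
    return [
        "full_attention" if (i + 1) % period == 0 else "sliding_attention"
        for i in range(num_hidden_layers)
    ]


def expand_per_layer(value, num_hidden_layers, name):
    """Broadcast a scalar to all layers, or validate a per-layer sequence."""
    if isinstance(value, Sequence) and not isinstance(value, str):
        if len(value) != num_hidden_layers:
            raise ValueError(
                f"{name} needs one value per layer. "
                f"Expected {num_hidden_layers} values but got {len(value)}."
            )
        return list(value)
    return [value] * num_hidden_layers


layer_types = default_layer_types(35)
intermediate_size = expand_per_layer(16_384, 35, "intermediate_size")
sparsity = expand_per_layer((0.95,) * 10 + (0.0,) * 25, 35, "activation_sparsity_pattern")
print(layer_types[:5], len(intermediate_size), sparsity[9:11])
```

With the default of 35 layers this yields seven "full_attention" layers, the last of which is the final layer.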
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3nAudioConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`], based on Gogole's
|
||||||
|
[Universal Speech Model](). It is used to instantiate an Gemma3nAudioEncoder model according to the specified
|
||||||
|
arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar
|
||||||
|
configuration to that of the Gemma 3n E4B, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
|
||||||
|
|
||||||
|
Configuration objects that inherit from [`Gemma3nAudioConfig`] and can be used to control the model outputs. Read
|
||||||
|
the documentation from [`Gemma3nAudioConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 128):
|
||||||
|
Vocabulary size of the additional hard-token embeddings for audio model. These augment the embeddings
|
||||||
|
included in the `Gemma3nTextModel` to provide, e.g., the end of audio and audio soft token placeholder
|
||||||
|
tokens when converting `input_ids` to embeddings in the `Gemma3nForConditionalGeneration` model.
|
||||||
|
vocab_offset (`int`, *optional*, defaults to 262272):
|
||||||
|
Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
|
||||||
|
0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
|
||||||
|
input_feat_size (`int`, *optional*, defaults to 128):
|
||||||
|
The number of channels in each mel-spectrogram frame.
|
||||||
|
hidden_size (`int`, *optional*, defaults to 1536):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
gradient_clipping (`float`, *optional*, defaults to 10000000000.0):
|
||||||
|
Clipping value used to stabilize extremely large gradient values.
|
||||||
|
conf_attention_chunk_size (`int`, *optional*, defaults to 12):
|
||||||
|
The sub-sequence size for local attention processing inside the Conformer ("conf") section of the
|
||||||
|
Universal Speech Model.
|
||||||
|
conf_attention_context_left (`int`, *optional*, defaults to 13):
|
||||||
|
The left context size of the local attention inside the Conformer ("conf") section of the
|
||||||
|
Universal Speech Model.
|
||||||
|
conf_attention_context_right (`int`, *optional*, defaults to 0):
|
||||||
|
The right context size of the local attention inside the Conformer ("conf") section of the
|
||||||
|
Universal Speech Model.
|
||||||
|
conf_attention_logit_cap (`float`, *optional*, defaults to 50.0):
|
||||||
|
Logit cap applied during local attention inside the Conformer ("conf") section of the
|
||||||
|
Universal Speech Model.
|
||||||
|
conf_num_attention_heads (`int`, *optional*, defaults to 8):
|
||||||
|
The number of attention heads in local attention inside the Conformer ("conf") section of the
|
||||||
|
Universal Speech Model.
|
||||||
|
conf_num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||||
|
The number of layers that use local attention inside the Conformer ("conf") section of the
|
||||||
|
Universal Speech Model.
|
||||||
|
conf_conv_kernel_size (`int`, *optional*, defaults to 5):
|
||||||
|
Convolution kernel size for the conformer block inside the Conformer ("conf") section of the
|
||||||
|
Universal Speech Model.
|
||||||
|
conf_reduction_factor (`int`, *optional*, defaults to 4):
|
||||||
|
Reduction factor used in the conformer block inside the Conformer ("conf") section of the
|
||||||
|
Universal Speech Model.
|
||||||
|
conf_residual_weight (`float`, *optional*, defaults to 0.5):
|
||||||
|
Residual connection weight inside the Conformer ("conf") section of the
|
||||||
|
Universal Speech Model.
|
||||||
|
sscp_conv_channel_size (`tuple(int, int)`, *optional*, defaults to `(128, 32)`):
|
||||||
|
The channel sizes for the first and second convolutional layers in the Sub-sample Convolution Projection
|
||||||
|
("sscp") section of the Universal Speech Model.
|
||||||
|
sscp_conv_group_norm_eps (`float`, *optional*, defaults to 0.001):
|
||||||
|
Epsilon used in group normalization in the subsample convolution projection in the Sub-sample Convolution
|
||||||
|
Projection ("sscp") section of the Universal Speech Model.
|
||||||
|
sscp_conv_kernel_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((3, 3), (3, 3))`):
|
||||||
|
Kernel sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
|
||||||
|
Convolution Projection ("sscp") section of the Universal Speech Model. The kernel sizes are specified as a
|
||||||
|
tuple of height and width for each layer, where the height corresponds to the time dimension and the width
|
||||||
|
corresponds to the frequency dimension.
|
||||||
|
sscp_conv_stride_size (`tuple(tuple(int, int), tuple(int, int))`, *optional*, defaults to `((2, 2), (2, 2))`):
|
||||||
|
Stride sizes of the two convolutional layers in the subsample convolution projection in the Sub-sample
|
||||||
|
Convolution Projection ("sscp") section of the Universal Speech Model. The stride sizes are specified as a
|
||||||
|
tuple of height and width for each layer, where the height corresponds to the time dimension and the width
|
||||||
|
corresponds to the frequency dimension.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import Gemma3nAudioConfig, Gemma3nAudioEncoder
|
||||||
|
|
||||||
|
>>> # Initializing a Gemma3nAudioEncoder gemma3n_audio-E4B-style configuration
|
||||||
|
>>> configuration = Gemma3nAudioConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a model from the gemma3n_audio-E4B style configuration
|
||||||
|
>>> model = Gemma3nAudioEncoder(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_type = "gemma3n_audio"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size: int = 128,
|
||||||
|
vocab_offset: int = 262_144 + 128, # text vocab size + vision vocab size
|
||||||
|
input_feat_size: int = 128,
|
||||||
|
hidden_size: int = 1536,
|
||||||
|
rms_norm_eps: float = 1e-6,
|
||||||
|
gradient_clipping: float = 10_000_000_000.0,
|
||||||
|
conf_attention_chunk_size: int = 12,
|
||||||
|
conf_attention_context_left: int = 13,
|
||||||
|
conf_attention_context_right: int = 0,
|
||||||
|
conf_attention_logit_cap: float = 50.0,
|
||||||
|
conf_num_attention_heads: int = 8,
|
||||||
|
conf_num_hidden_layers: int = 12,
|
||||||
|
conf_conv_kernel_size: int = 5,
|
||||||
|
conf_reduction_factor: int = 4,
|
||||||
|
conf_residual_weight: float = 0.5,
|
||||||
|
sscp_conv_channel_size: tuple[int, int] = (128, 32),
|
||||||
|
sscp_conv_group_norm_eps: float = 1e-3,
|
||||||
|
sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = (
|
||||||
|
(3, 3),
|
||||||
|
(3, 3),
|
||||||
|
),
|
||||||
|
sscp_conv_stride_size: tuple[tuple[int, int], tuple[int, int]] = (
|
||||||
|
(2, 2),
|
||||||
|
(2, 2),
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.input_feat_size = input_feat_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.vocab_offset = vocab_offset
|
||||||
|
self.gradient_clipping = gradient_clipping
|
||||||
|
self.conf_attention_chunk_size = conf_attention_chunk_size
|
||||||
|
self.conf_attention_context_left = conf_attention_context_left
|
||||||
|
self.conf_attention_context_right = conf_attention_context_right
|
||||||
|
self.conf_attention_logit_cap = conf_attention_logit_cap
|
||||||
|
self.conf_num_attention_heads = conf_num_attention_heads
|
||||||
|
self.conf_num_hidden_layers = conf_num_hidden_layers
|
||||||
|
self.conf_conv_kernel_size = conf_conv_kernel_size
|
||||||
|
self.conf_reduction_factor = conf_reduction_factor
|
||||||
|
self.conf_residual_weight = conf_residual_weight
|
||||||
|
self.sscp_conv_channel_size = sscp_conv_channel_size
|
||||||
|
self.sscp_conv_group_norm_eps = sscp_conv_group_norm_eps
|
||||||
|
self.sscp_conv_kernel_size = sscp_conv_kernel_size
|
||||||
|
self.sscp_conv_stride_size = sscp_conv_stride_size
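The `vocab_offset` arithmetic is easier to follow with concrete numbers. This is a minimal sketch, assuming a hypothetical helper `to_table_index` (not a library function) and assuming the lookup is the plain subtraction that the docstring describes. The vocabulary sizes and token ids are the defaults that appear in these config classes.

```python
TEXT_VOCAB_SIZE = 262_144
VISION_VOCAB_SIZE = 128
AUDIO_VOCAB_SIZE = 128

# The vision table starts right after the text vocab, the audio table right after vision.
VISION_VOCAB_OFFSET = TEXT_VOCAB_SIZE
AUDIO_VOCAB_OFFSET = TEXT_VOCAB_SIZE + VISION_VOCAB_SIZE


def to_table_index(token_id: int, vocab_offset: int, vocab_size: int) -> int:
    """Map a tokenizer id to a row of a 0-indexed embedding table (hypothetical helper)."""
    index = token_id - vocab_offset
    if not 0 <= index < vocab_size:
        raise ValueError(f"token id {token_id} is outside [{vocab_offset}, {vocab_offset + vocab_size})")
    return index


# Default eoa_token_id and audio_token_id from Gemma3nConfig.
eoa_row = to_table_index(262_272, AUDIO_VOCAB_OFFSET, AUDIO_VOCAB_SIZE)
audio_row = to_table_index(262_273, AUDIO_VOCAB_OFFSET, AUDIO_VOCAB_SIZE)
```

Under these assumptions the end-of-audio token lands on row 0 of the audio table and the audio soft token on row 1.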
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3nVisionConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration for a timm backbone [`TimmWrapper`]. It is used to
instantiate a timm model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B
vision tower, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
|
||||||
|
|
||||||
|
Configuration objects inherit from [`Gemma3nVisionConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`Gemma3nVisionConfig`] for more information.
|
||||||
|
|
||||||
|
The config loads ImageNet label descriptions and stores them in the `id2label` attribute. The `label2id` attribute
is set to `None` for default ImageNet models because some label descriptions are not unique.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
do_pooling (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to do pooling for the last_hidden_state in `TimmWrapper` or not.
|
||||||
|
architecture (`str`, *optional*, defaults to `"mobilenetv5_300m_enc"`):
|
||||||
|
Determines vision architecture for TimmWrapper.
|
||||||
|
hidden_size (`int`, *optional*, defaults to 2048):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
vocab_size (`int`, *optional*, defaults to 128):
|
||||||
|
Vocabulary size of the additional hard-token embeddings for vision model.
|
||||||
|
vocab_offset (`int`, *optional*, defaults to 262144):
|
||||||
|
Offset between the tokenizer vocab index for the token ids embedded by `Gemma3nMultimodalEmbedder` and the
|
||||||
|
0-indexed `Gemma3nMultimodalEmbedder.embedding` table.
|
||||||
|
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
>>> from transformers import Gemma3nVisionConfig, TimmWrapper
|
||||||
|
|
||||||
|
>>> # Initializing a TimmWrapper gemma3n_vision-E4B-style configuration
|
||||||
|
>>> configuration = Gemma3nVisionConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a gemma3n_vision-E4B-style TimmWrapper from the configuration
|
||||||
|
>>> model = TimmWrapper(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_type = "gemma3n_vision"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
initializer_range: float = 0.02,
|
||||||
|
do_pooling: bool = False,
|
||||||
|
architecture: str = "mobilenetv5_300m_enc",
|
||||||
|
hidden_size: int = 2048,
|
||||||
|
vocab_size: int = 128,
|
||||||
|
vocab_offset: int = 262_144,
|
||||||
|
rms_norm_eps: float = 1e-06,
|
||||||
|
model_args: Optional[dict] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.do_pooling = do_pooling
|
||||||
|
self.model_args = model_args # named "model_args" for BC with timm
|
||||||
|
self.architecture = architecture
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.vocab_offset = vocab_offset
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, config_dict: dict[str, Any], **kwargs):
|
||||||
|
label_names = config_dict.get("label_names", None)
|
||||||
|
is_custom_model = "num_labels" in kwargs or "id2label" in kwargs
|
||||||
|
|
||||||
|
# if no labels added to config, use imagenet labeller in timm
|
||||||
|
if label_names is None and not is_custom_model:
|
||||||
|
requires_backends(cls, ["timm"])
|
||||||
|
imagenet_subset = infer_imagenet_subset(config_dict)
|
||||||
|
if imagenet_subset:
|
||||||
|
dataset_info = ImageNetInfo(imagenet_subset)
|
||||||
|
synsets = dataset_info.label_names()
|
||||||
|
label_descriptions = dataset_info.label_descriptions(as_dict=True)
|
||||||
|
label_names = [label_descriptions[synset] for synset in synsets]
|
||||||
|
|
||||||
|
if label_names is not None and not is_custom_model:
|
||||||
|
kwargs["id2label"] = dict(enumerate(label_names))
|
||||||
|
|
||||||
|
# if all label names are unique, create label2id mapping as well
|
||||||
|
if len(set(label_names)) == len(label_names):
|
||||||
|
kwargs["label2id"] = {name: i for i, name in enumerate(label_names)}
|
||||||
|
else:
|
||||||
|
kwargs["label2id"] = None
|
||||||
|
|
||||||
|
# timm config stores the `num_classes` attribute in both the root of config and in the "pretrained_cfg" dict.
|
||||||
|
# We are removing these attributes in order to have the native `transformers` num_labels attribute in config
|
||||||
|
# and to avoid duplicate attributes
|
||||||
|
num_labels_in_kwargs = kwargs.pop("num_labels", None)
|
||||||
|
num_labels_in_dict = config_dict.pop("num_classes", None)
|
||||||
|
|
||||||
|
# passed num_labels has priority over num_classes in config_dict
|
||||||
|
kwargs["num_labels"] = num_labels_in_kwargs or num_labels_in_dict
|
||||||
|
|
||||||
|
# pop num_classes from "pretrained_cfg",
|
||||||
|
# it is not necessary to have it, only root one is used in timm
|
||||||
|
if "pretrained_cfg" in config_dict and "num_classes" in config_dict["pretrained_cfg"]:
|
||||||
|
config_dict["pretrained_cfg"].pop("num_classes", None)
|
||||||
|
|
||||||
|
return super().from_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
output = super().to_dict()
|
||||||
|
output["num_classes"] = self.num_labels
|
||||||
|
output["label_names"] = list(self.id2label.values())
|
||||||
|
output.pop("id2label", None)
|
||||||
|
output.pop("label2id", None)
|
||||||
|
return output
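The label handling in `from_dict` reduces to two small rules: `id2label` is always built, and `label2id` is built only when every label name is unique. A minimal sketch of those rules follows, using a hypothetical helper `build_label_maps` and made-up label names.

```python
from typing import Optional


def build_label_maps(label_names: list[str]) -> tuple[dict[int, str], Optional[dict[str, int]]]:
    """Hypothetical helper: id2label always, label2id only when names are unique."""
    id2label = dict(enumerate(label_names))
    if len(set(label_names)) == len(label_names):
        label2id = {name: i for i, name in enumerate(label_names)}
    else:
        # A reverse mapping would be ambiguous, so none is produced.
        label2id = None
    return id2label, label2id


unique_maps = build_label_maps(["cat", "dog"])
duplicate_maps = build_label_maps(["crane", "crane"])

# `a or b` falls back to b whenever a is falsy, so None (and also 0) falls through.
num_labels = None or 1000
```

The last line illustrates the `num_labels_in_kwargs or num_labels_in_dict` expression: a value passed by the caller wins unless it is falsy.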
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3nConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`Gemma3nForConditionalGeneration`]. It is used to
|
||||||
|
instantiate a Gemma3nForConditionalGeneration according to the specified arguments, defining the model
|
||||||
|
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
|
||||||
|
Gemma3n-E4B.
|
||||||
|
|
||||||
|
e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B)
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text_config (`Union[Gemma3nTextConfig, dict]`, *optional*):
|
||||||
|
The config object of the text backbone.
|
||||||
|
vision_config (`Union[AutoConfig, dict]`, *optional*):
|
||||||
|
Custom vision config or dict.
|
||||||
|
audio_config (`Union[AutoConfig, dict]`, *optional*):
|
||||||
|
Custom audio config or dict.
|
||||||
|
audio_soft_tokens_per_image (`int`, *optional*, defaults to 188):
|
||||||
|
The number of soft tokens per audio clip.
|
||||||
|
vision_soft_tokens_per_image (`int`, *optional*, defaults to 256):
|
||||||
|
The number of soft tokens per image.
|
||||||
|
boi_token_id (`int`, *optional*, defaults to 255999):
|
||||||
|
The begin-of-image token index to wrap the image prompt.
|
||||||
|
eoi_token_id (`int`, *optional*, defaults to 262144):
|
||||||
|
The end-of-image token index to wrap the image prompt.
|
||||||
|
image_token_id (`int`, *optional*, defaults to 262145):
|
||||||
|
The image token index to encode the image prompt.
|
||||||
|
boa_token_id (`int`, *optional*, defaults to 256000):
|
||||||
|
The begin-of-audio token index to wrap the audio prompt.
|
||||||
|
eoa_token_id (`int`, *optional*, defaults to 262272):
|
||||||
|
The end-of-audio token index to wrap the audio prompt.
|
||||||
|
audio_token_id (`int`, *optional*, defaults to 262273):
|
||||||
|
The audio token index to encode the audio prompt.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
>>> from transformers import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nForConditionalGeneration, Gemma3nTextConfig, Gemma3nVisionConfig

>>> # Initializing a MobileNet vision config, which is loaded from TIMM
>>> vision_config = Gemma3nVisionConfig()

>>> # Initializing a Gemma3n Audio config
>>> audio_config = Gemma3nAudioConfig()

>>> # Initializing a Gemma3n Text config
>>> text_config = Gemma3nTextConfig()

>>> # Initializing a Gemma3n gemma-3n-E4B style configuration
>>> configuration = Gemma3nConfig(text_config, vision_config, audio_config)

>>> # Initializing a model from the gemma-3n-E4B style configuration
>>> model = Gemma3nForConditionalGeneration(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "gemma3n"
|
||||||
|
sub_configs = {
|
||||||
|
"text_config": Gemma3nTextConfig,
|
||||||
|
"vision_config": Gemma3nVisionConfig,
|
||||||
|
"audio_config": Gemma3nAudioConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
text_config: Optional[Union[Gemma3nTextConfig, dict[str, Any]]] = None,
|
||||||
|
vision_config: Optional[Union[Gemma3nVisionConfig, dict[str, Any]]] = None,
|
||||||
|
audio_config: Optional[Union[Gemma3nAudioConfig, dict[str, Any]]] = None,
|
||||||
|
audio_soft_tokens_per_image: int = 188,
|
||||||
|
vision_soft_tokens_per_image: int = 256,
|
||||||
|
boi_token_id: int = 255_999,
|
||||||
|
eoi_token_id: int = 262_144,
|
||||||
|
image_token_id: int = 262_145,
|
||||||
|
boa_token_id: int = 256_000,
|
||||||
|
eoa_token_id: int = 262_272,
|
||||||
|
audio_token_id: int = 262_273,
|
||||||
|
initializer_range: float = 0.02,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
if isinstance(text_config, dict):
|
||||||
|
text_config = Gemma3nTextConfig(**text_config)
|
||||||
|
elif text_config is None:
|
||||||
|
text_config = Gemma3nTextConfig()
|
||||||
|
logger.info("text_config is None. Using default Gemma3nTextConfig.")
|
||||||
|
|
||||||
|
if isinstance(vision_config, dict):
|
||||||
|
vision_config = Gemma3nVisionConfig(**vision_config)
|
||||||
|
elif vision_config is None:
|
||||||
|
vision_config = Gemma3nVisionConfig()
|
||||||
|
logger.info("vision_config is None. Using default Gemma3nVisionConfig.")
|
||||||
|
|
||||||
|
if isinstance(audio_config, dict):
|
||||||
|
audio_config = Gemma3nAudioConfig(**audio_config)
|
||||||
|
elif audio_config is None:
|
||||||
|
audio_config = Gemma3nAudioConfig()
|
||||||
|
logger.info("audio_config is None. Using default Gemma3nAudioConfig.")
|
||||||
|
|
||||||
|
self.text_config = text_config
|
||||||
|
self.vision_config = vision_config
|
||||||
|
self.audio_config = audio_config
|
||||||
|
|
||||||
|
self.audio_soft_tokens_per_image = audio_soft_tokens_per_image
|
||||||
|
self.vision_soft_tokens_per_image = vision_soft_tokens_per_image
|
||||||
|
self.boi_token_id = boi_token_id
|
||||||
|
self.eoi_token_id = eoi_token_id
|
||||||
|
self.image_token_id = image_token_id
|
||||||
|
self.boa_token_id = boa_token_id
|
||||||
|
self.eoa_token_id = eoa_token_id
|
||||||
|
self.audio_token_id = audio_token_id
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Gemma3nAudioConfig", "Gemma3nConfig", "Gemma3nTextConfig", "Gemma3nVisionConfig"]
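`Gemma3nConfig.__init__` accepts each sub-config as an instance, a plain dict, or `None`, and normalizes all three to an instance. Here is a minimal sketch of that pattern using a toy dataclass. `ToyTextConfig`, its field, and `normalize_sub_config` are illustrative stand-ins, not library names.

```python
from dataclasses import dataclass
from typing import Any, Optional, Union


@dataclass
class ToyTextConfig:
    """Stand-in for a sub-config class; the field and its default are illustrative only."""

    hidden_size: int = 2048


def normalize_sub_config(value: Optional[Union[ToyTextConfig, dict[str, Any]]]) -> ToyTextConfig:
    """Return a config instance regardless of how the caller supplied it."""
    if isinstance(value, dict):
        # A dict is treated as keyword arguments for the config class.
        return ToyTextConfig(**value)
    if value is None:
        # Nothing supplied: fall back to the defaults.
        return ToyTextConfig()
    # Already an instance: use it unchanged.
    return value


from_dict = normalize_sub_config({"hidden_size": 512})
from_none = normalize_sub_config(None)
existing = ToyTextConfig(hidden_size=1024)
from_instance = normalize_sub_config(existing)
```

Keeping an existing instance untouched means a caller can build one sub-config and share it, while configs loaded from JSON arrive as dicts and take the first branch.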
|
||||||
807
src/transformers/models/gemma3n/convert_gemma3n_weights.py
Normal file
@@ -0,0 +1,807 @@
|
|||||||
|
# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
||||||
|
|
r"""Utility to convert Gemma models from Orbax to an HF Transformers checkpoint.

python src/transformers/models/gemma3n/convert_gemma3n_weights.py \
    --variant='gemma3n_e4b' \
    --tokenizer_path="$HOME/nano3/checkpoints/tokenizer/gemma-3n-tokenizer.model" \
    --checkpoint_path="$HOME/nano3/checkpoints/g251_orbax/" \
    --output_path="$HOME/nano3/checkpoints/g251_vision_encoder/"
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from collections.abc import Iterable, Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import accelerate
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import tree
|
||||||
|
from absl import app, flags, logging
|
||||||
|
from orbax import checkpoint as obc
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
Gemma3nAudioConfig,
|
||||||
|
Gemma3nAudioFeatureExtractor,
|
||||||
|
Gemma3nConfig,
|
||||||
|
Gemma3nForConditionalGeneration,
|
||||||
|
Gemma3nProcessor,
|
||||||
|
Gemma3nTextConfig,
|
||||||
|
Gemma3nVisionConfig,
|
||||||
|
GemmaTokenizerFast,
|
||||||
|
GenerationConfig,
|
||||||
|
SiglipImageProcessorFast,
|
||||||
|
)
|
||||||
|
from transformers.image_utils import PILImageResampling
|
||||||
|
|
||||||
|
|
||||||
|
# ==== Internal Constants and Classes ====
|
||||||
|
|
||||||
|
|
||||||
|
_CHAT_TEMPLATE = """{{ bos_token }}
|
||||||
|
{%- if messages[0]['role'] == 'system' -%}
|
||||||
|
{%- if messages[0]['content'] is string -%}
|
||||||
|
{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
|
||||||
|
{%- else -%}
|
||||||
|
{%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- set loop_messages = messages[1:] -%}
|
||||||
|
{%- else -%}
|
||||||
|
{%- set first_user_prefix = "" -%}
|
||||||
|
{%- set loop_messages = messages -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- for message in loop_messages -%}
|
||||||
|
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
|
||||||
|
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if (message['role'] == 'assistant') -%}
|
||||||
|
{%- set role = "model" -%}
|
||||||
|
{%- else -%}
|
||||||
|
{%- set role = message['role'] -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else "") }}
|
||||||
|
{%- if message['content'] is string -%}
|
||||||
|
{{ message['content'] | trim }}
|
||||||
|
{%- elif message['content'] is iterable -%}
|
||||||
|
{%- for item in message['content'] -%}
|
||||||
|
{%- if item['type'] == 'audio' -%}
|
||||||
|
{{ '<audio_soft_token>' }}
|
||||||
|
{%- elif item['type'] == 'image' -%}
|
||||||
|
{{ '<image_soft_token>' }}
|
||||||
|
{%- elif item['type'] == 'text' -%}
|
||||||
|
{{ item['text'] | trim }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- else -%}
|
||||||
|
{{ raise_exception("Invalid content type") }}
|
||||||
|
{%- endif -%}
|
||||||
|
{{ '<end_of_turn>\n' }}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- if add_generation_prompt -%}
|
||||||
|
{{'<start_of_turn>model\n'}}
|
||||||
|
{%- endif -%}
|
||||||
|
"""
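The control flow of the chat template can be restated in plain Python to make its rules explicit: an optional system message becomes a prefix of the first turn, roles must alternate starting with `user`, `assistant` is rendered as `model`, and image or audio items become soft-token placeholders. The sketch below is an approximation for reading purposes, not a Jinja renderer. The function name `render_chat`, the `"<bos>"` default, and the guard for an empty message list are assumptions.

```python
from typing import Any


def render_chat(messages: list[dict[str, Any]], bos_token: str = "<bos>", add_generation_prompt: bool = False) -> str:
    """Plain-Python approximation of the template's control flow (hypothetical helper)."""
    first_user_prefix = ""
    loop_messages = messages
    # The empty-list guard is an addition; the template indexes messages[0] directly.
    if messages and messages[0]["role"] == "system":
        content = messages[0]["content"]
        text = content if isinstance(content, str) else content[0]["text"]
        first_user_prefix = text + "\n\n"
        loop_messages = messages[1:]

    parts = [bos_token]
    for index, message in enumerate(loop_messages):
        # Even positions must be user turns, odd positions must not be.
        if (message["role"] == "user") != (index % 2 == 0):
            raise ValueError("Conversation roles must alternate user/assistant/user/assistant/...")
        role = "model" if message["role"] == "assistant" else message["role"]
        parts.append("<start_of_turn>" + role + "\n" + (first_user_prefix if index == 0 else ""))
        content = message["content"]
        if isinstance(content, str):
            parts.append(content.strip())
        elif isinstance(content, (list, tuple)):
            for item in content:
                if item["type"] == "audio":
                    parts.append("<audio_soft_token>")
                elif item["type"] == "image":
                    parts.append("<image_soft_token>")
                elif item["type"] == "text":
                    parts.append(item["text"].strip())
        else:
            raise ValueError("Invalid content type")
        parts.append("<end_of_turn>\n")
    if add_generation_prompt:
        parts.append("<start_of_turn>model\n")
    return "".join(parts)


prompt = render_chat(
    [
        {"role": "system", "content": "Be brief."},
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": " What is this? "}]},
    ],
    add_generation_prompt=True,
)
```

Note that the system text is not emitted as its own turn. It is folded into the first user turn, which is why the alternation check runs on the messages after the system entry.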
|
||||||
|
|
||||||
|
_DTYPES = {"float32", "bfloat16", "float16"}
|
||||||
|
|
||||||
|
_SLIDING_WINDOW_PATTERN = 5
|
||||||
|
|
||||||
|
_AUDIO_ENCODER_PARAMETER = "AudioEncoder/encoder"
|
||||||
|
_AUDIO_ENCODER_CONFORMER = f"{_AUDIO_ENCODER_PARAMETER}/conformer/stacked_layers"
|
||||||
|
_AUDIO_ENCODER_SSCP = f"{_AUDIO_ENCODER_PARAMETER}/feature"
|
||||||
|
|
||||||
|
_TRANSFORMER_PARAMETER = "transformer"
|
||||||
|
_TRANSFORMER_ALTUP_PROJ = f"{_TRANSFORMER_PARAMETER}/altup_projection_"
|
||||||
|
_TRANSFORMER_ALTUP_UNEMB = f"{_TRANSFORMER_PARAMETER}/altup_unembed_projection_"
|
||||||
|
_TRANSFORMER_DECODER_BLOCK = f"{_TRANSFORMER_PARAMETER}/stacked_layers/attention_type_"
|
||||||
|
_TRANSFORMER_DECODER_BLOCK_LEN = len(_TRANSFORMER_DECODER_BLOCK)
|
||||||
|
_TRANSFORMER_EMBEDDER = f"{_TRANSFORMER_PARAMETER}/embedder"
|
||||||
|
_TRANSFORMER_FINAL_NORM = "transformer/final_norm"
|
||||||
|
_TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/"
|
||||||
|
_TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX)
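The precomputed `*_LEN` constants suggest that prefixes are sliced off checkpoint paths before matching. The code that does this is not shown here, so the following is only a sketch of one plausible use. The helper name `strip_post_training_prefix` and the example paths are assumptions.

```python
_PREFIX = "rlx_networks/policy_network/"
_PREFIX_LEN = len(_PREFIX)


def strip_post_training_prefix(path: str) -> str:
    """Hypothetical helper: drop the wrapper prefix if present, else return the path unchanged."""
    if path.startswith(_PREFIX):
        return path[_PREFIX_LEN:]
    return path


stripped = strip_post_training_prefix("rlx_networks/policy_network/transformer/final_norm")
untouched = strip_post_training_prefix("transformer/embedder")
```

Normalizing paths this way would let the same matching logic handle checkpoints saved with or without the post-training wrapper.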
|
||||||
|
|
||||||
|
# _MOBILE_NET_CONFIG = Gemma3nVisionConfig.from_pretrained("")
|
||||||
|
|
||||||
|
_MOBILE_NET_PREFIX = "mobilenet"
|
||||||
|
_MOBILE_NET_TIMM_SUMMED_BLOCK_SIZES = [3, 8, 45, 84]
|
||||||
|
_MOBILE_NET_CONV = "block_group_conv2d_"
|
||||||
|
_MOBILE_NET_FIB = "block_group_fused_ib_"
|
||||||
|
_MOBILE_NET_MQA = "block_group_mmqa_"
|
||||||
|
_MOBILE_NET_MSFA = "block_adapter_"
|
||||||
|
_MOBILE_NET_UIB = "block_group_uib_"
|
||||||
|
_MOBILE_NET_UIB_HAS_DW_START = {
|
||||||
|
(1, 0),
|
||||||
|
(1, 1),
|
||||||
|
(1, 2),
|
||||||
|
(1, 3),
|
||||||
|
(1, 4),
|
||||||
|
(2, 0),
|
||||||
|
(2, 1),
|
||||||
|
(2, 2),
|
||||||
|
(2, 3),
|
||||||
|
(2, 4),
|
||||||
|
(2, 5),
|
||||||
|
(2, 6),
|
||||||
|
(2, 7),
|
||||||
|
(3, 0),
|
||||||
|
}
|
||||||
|
_MOBILE_NET_UIB_HAS_DW_MID = {
|
||||||
|
(1, 0),
|
||||||
|
(2, 0),
|
||||||
|
(3, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
_VARIANT_GEMMA_3_2B = "gemma3n_e2b"
|
||||||
|
_VARIANT_GEMMA_3_4B = "gemma3n_e4b"
|
||||||
|
_VARIANTS: Mapping[str, Gemma3nConfig] = {
|
||||||
|
_VARIANT_GEMMA_3_2B: Gemma3nConfig(
|
||||||
|
text_config=Gemma3nTextConfig(
|
||||||
|
intermediate_size=2048 * 4,
|
||||||
|
num_hidden_layers=30,
|
||||||
|
activation_sparsity_pattern=(0.95,) * 10 + (0.0,) * 20,
|
||||||
|
num_kv_shared_layers=10,
|
||||||
|
),
|
||||||
|
vision_config=Gemma3nVisionConfig(),
|
||||||
|
audio_config=Gemma3nAudioConfig(),
|
||||||
|
),
|
||||||
|
_VARIANT_GEMMA_3_4B: Gemma3nConfig(
|
||||||
|
text_config=Gemma3nTextConfig(),
|
||||||
|
vision_config=Gemma3nVisionConfig(),
|
||||||
|
audio_config=Gemma3nAudioConfig(),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ==== Flags ====
|
||||||
|
|
||||||
|
_AUDIO_DTYPE = flags.DEFINE_enum(
|
||||||
|
name="audio_dtype",
|
||||||
|
default="bfloat16",
|
||||||
|
help="The floating point precision (aka dtype) of the audio encoder.",
|
||||||
|
enum_values=_DTYPES,
|
||||||
|
)
|
||||||
|
|
||||||
|
_CHECKPOINT_PATH = flags.DEFINE_string(
|
||||||
|
name="checkpoint_path",
|
||||||
|
default=None,
|
||||||
|
help="Path to the Orbax checkpoint.",
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
_INCLUDE_CHAT_TEMPLATE = flags.DEFINE_bool(
|
||||||
|
name="include_chat_template", default=False, help="If true, will save the default chat template with the tokenizer"
|
||||||
|
)
|
||||||
|
|
||||||
|
_OUTPUT_PATH = flags.DEFINE_string(
|
||||||
|
name="output_path",
|
||||||
|
default=None,
|
||||||
|
help="Path to store the HF checkpoint.",
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
_TRANSFORMER_DTYPE = flags.DEFINE_enum(
|
||||||
|
name="text_dtype",
|
||||||
|
default="bfloat16",
|
||||||
|
help="The floating point precision (aka dtype) of the text model.",
|
||||||
|
enum_values=_DTYPES,
|
||||||
|
)
|
||||||
|
|
||||||
|
_TOKENIZER_PATH = flags.DEFINE_string(
|
||||||
|
name="tokenizer_path",
|
||||||
|
default=None,
|
||||||
|
help="Path to the SentencePiece model file.",
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
_VARIANT = flags.DEFINE_enum(
|
||||||
|
name="variant",
|
||||||
|
default=_VARIANT_GEMMA_3_4B,
|
||||||
|
help="The model variant to convert.",
|
||||||
|
enum_values=set(_VARIANTS.keys()),
|
||||||
|
)
|
||||||
|
|
||||||
|
_VERBOSE = flags.DEFINE_bool(
|
||||||
|
name="verbose",
|
||||||
|
default=False,
|
||||||
|
help="If true, log the path, shape, and dtype of every converted layer.",
|
||||||
|
)
|
||||||
|
|
||||||
|
_VISION_DTYPE = flags.DEFINE_enum(
|
||||||
|
name="vision_dtype",
|
||||||
|
default="bfloat16",
|
||||||
|
help="The floating point precision (aka dtype) of the vision encoder.",
|
||||||
|
enum_values=_DTYPES,
|
||||||
|
)
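Taken together, the flags define a command-line surface with three required paths, one enum for the variant, one dtype enum per modality, and two booleans. The sketch below restates that surface with the standard library's `argparse` instead of absl, purely so that it runs without extra dependencies. It is a stand-in, not how the script parses its arguments, and the example paths are made up.

```python
import argparse

_DTYPES = ("float32", "bfloat16", "float16")
_VARIANTS = ("gemma3n_e2b", "gemma3n_e4b")


def build_parser() -> argparse.ArgumentParser:
    """argparse stand-in for the absl flags; not what the script itself uses."""
    parser = argparse.ArgumentParser()
    # Required paths, matching the flags declared with required=True.
    parser.add_argument("--checkpoint_path", required=True)
    parser.add_argument("--output_path", required=True)
    parser.add_argument("--tokenizer_path", required=True)
    # Enum-style flags restrict values to a fixed set.
    parser.add_argument("--variant", choices=_VARIANTS, default="gemma3n_e4b")
    for name in ("audio_dtype", "text_dtype", "vision_dtype"):
        parser.add_argument(f"--{name}", choices=_DTYPES, default="bfloat16")
    # Boolean flags default to False.
    parser.add_argument("--include_chat_template", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    return parser


args = build_parser().parse_args(
    ["--checkpoint_path", "ckpt/", "--output_path", "out/", "--tokenizer_path", "tok.model", "--text_dtype", "float32"]
)
```

Each modality has its own dtype flag, so the text weights can be converted at a different precision than the audio or vision weights, as the example invocation does.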
|
||||||
|
|
||||||
|
|
||||||
|
def convert_audio_encoder_weights(
|
||||||
|
config: Gemma3nAudioConfig,
|
||||||
|
path: str,
|
||||||
|
param: str,
|
||||||
|
weights: np.ndarray,
|
||||||
|
) -> Iterable[tuple[str, np.ndarray]]:
|
||||||
|
converted_paths: list[str] = []
|
||||||
|
converted_weights: list[Any] = []
|
||||||
|
|
||||||
|
if path.startswith(_AUDIO_ENCODER_CONFORMER):
|
||||||
|
assert weights.shape[0] == config.conf_num_hidden_layers
|
||||||
|
|
||||||
|
for i, matrix in enumerate(weights):
|
||||||
|
if "fflayer_end" in path:
|
||||||
|
base = f"conformer.{i}.ffw_layer_end"
|
||||||
|
|
||||||
|
if path.endswith("ffn_layer1"):
|
||||||
|
converted_paths.append(f"{base}.ffw_layer_1.weight")
|
||||||
|
converted_weights.append(matrix.transpose())
|
||||||
|
elif path.endswith("ffn_layer2"):
|
||||||
|
converted_paths.append(f"{base}.ffw_layer_2.weight")
|
||||||
|
converted_weights.append(matrix.transpose())
|
||||||
|
elif path.endswith("post_layer_norm"):
|
||||||
|
converted_paths.append(f"{base}.post_layer_norm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif path.endswith("pre_layer_norm"):
|
||||||
|
converted_paths.append(f"{base}.pre_layer_norm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif "fflayer_start" in path:
|
||||||
|
base = f"conformer.{i}.ffw_layer_start"
|
||||||
|
|
||||||
|
if path.endswith("ffn_layer1"):
|
||||||
|
converted_paths.append(f"{base}.ffw_layer_1.weight")
|
||||||
|
converted_weights.append(matrix.transpose())
|
||||||
|
elif path.endswith("ffn_layer2"):
|
||||||
|
converted_paths.append(f"{base}.ffw_layer_2.weight")
|
||||||
|
converted_weights.append(matrix.transpose())
|
||||||
|
elif path.endswith("post_layer_norm"):
|
||||||
|
converted_paths.append(f"{base}.post_layer_norm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif path.endswith("pre_layer_norm"):
|
||||||
|
converted_paths.append(f"{base}.pre_layer_norm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif path.endswith("final_ln"):
|
||||||
|
converted_paths.append(f"conformer.{i}.norm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif "lconv" in path:
|
||||||
|
base = f"conformer.{i}.lconv1d"
|
||||||
|
|
||||||
|
if path.endswith("conv_norm"):
|
||||||
|
converted_paths.append(f"{base}.conv_norm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif path.endswith("depthwise_conv1d"):
|
||||||
|
converted_paths.append(f"{base}.depthwise_conv1d.weight")
|
||||||
|
converted_weights.append(matrix.transpose())
|
||||||
|
elif path.endswith("linear_end"):
|
||||||
|
converted_paths.append(f"{base}.linear_end.weight")
|
||||||
|
converted_weights.append(matrix.transpose())
|
||||||
|
elif path.endswith("linear_start"):
|
||||||
|
converted_paths.append(f"{base}.linear_start.weight")
|
||||||
|
converted_weights.append(matrix.transpose())
|
||||||
|
elif path.endswith("ln"):
|
||||||
|
converted_paths.append(f"{base}.pre_layer_norm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif "trans_atten" in path:
|
||||||
|
base = f"conformer.{i}.attention"
|
||||||
|
|
||||||
|
if param == "per_dim_scale":
|
||||||
|
converted_paths.append(f"{base}.attn.per_dim_scale")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
|
||||||
|
if path.endswith("query_key_value_projection"):
|
||||||
|
converted_paths.extend(
|
||||||
|
[f"{base}.attn.q_proj.weight", f"{base}.attn.k_proj.weight", f"{base}.attn.v_proj.weight"]
|
||||||
|
)
|
||||||
|
converted_weights.extend(
|
||||||
|
[
|
||||||
|
m.reshape(config.hidden_size, config.hidden_size).transpose()
|
||||||
|
for m in matrix.transpose(1, 0, 2, 3)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
elif path.endswith("pos_proj"):
|
||||||
|
converted_paths.append(f"{base}.attn.relative_position_embedding.pos_proj.weight")
|
||||||
|
converted_weights.append(matrix.reshape(config.hidden_size, config.hidden_size).transpose())
|
||||||
|
elif path.endswith("post"):
|
||||||
|
converted_paths.append(f"{base}.post.weight")
|
||||||
|
converted_weights.append(matrix.transpose(2, 0, 1).reshape(config.hidden_size, config.hidden_size))
|
||||||
|
elif path.endswith("post_norm"):
|
||||||
|
converted_paths.append(f"{base}.post_norm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif path.endswith("pre_norm"):
|
||||||
|
converted_paths.append(f"{base}.pre_attn_norm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif path.startswith(_AUDIO_ENCODER_SSCP):
|
||||||
|
if path.endswith("input_proj"):
|
||||||
|
converted_paths.append("subsample_conv_projection.input_proj_linear.weight")
|
||||||
|
converted_weights.append(
|
||||||
|
weights.transpose(2, 0, 1).reshape(config.hidden_size, config.sscp_conv_channel_size[1] ** 2)
|
||||||
|
)
|
||||||
|
elif "norm_" in path:
|
||||||
|
index = int(path[-1])
|
||||||
|
converted_paths.append(f"subsample_conv_projection.conv_{index}.norm.weight")
|
||||||
|
converted_weights.append(weights)
|
||||||
|
elif "subsampling_" in path:
|
||||||
|
index = int(path[-1])
|
||||||
|
converted_paths.append(f"subsample_conv_projection.conv_{index}.conv.weight")
|
||||||
|
converted_weights.append(weights.transpose(3, 2, 0, 1))
|
||||||
|
|
||||||
|
if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
|
||||||
|
raise ValueError(
|
||||||
|
"The `converted_paths` and `converted_weights` should be the same "
|
||||||
|
f"length. Got {cpl} and {cwl}, respectively, for {path}."
|
||||||
|
)
|
||||||
|
|
||||||
|
return zip(converted_paths, converted_weights)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_transformer_weights(
|
||||||
|
config: Gemma3nTextConfig,
|
||||||
|
path: str,
|
||||||
|
param: str,
|
||||||
|
weights: np.ndarray,
|
||||||
|
) -> Iterable[tuple[str, np.ndarray]]:
|
||||||
|
if path.startswith(_TRANSFORMER_POST_TRAINING_PREFIX):
|
||||||
|
path = path[_TRANSFORMER_POST_TRAINING_PREFIX_LEN:]
|
||||||
|
|
||||||
|
converted_paths: list[str] = []
|
||||||
|
converted_weights: list[Any] = []
|
||||||
|
|
||||||
|
if path.startswith(_TRANSFORMER_ALTUP_PROJ):
|
||||||
|
index = int(path[-1])
|
||||||
|
converted_paths.append(f"altup_projections.{index}.weight")
|
||||||
|
converted_weights.append(weights.transpose())
|
||||||
|
elif path.startswith(_TRANSFORMER_ALTUP_UNEMB):
|
||||||
|
index = int(path[-1])
|
||||||
|
converted_paths.append(f"altup_unembed_projections.{index}.weight")
|
||||||
|
converted_weights.append(weights.transpose())
|
||||||
|
elif path.startswith(_TRANSFORMER_DECODER_BLOCK):
|
||||||
|
attention_type_index = int(path[_TRANSFORMER_DECODER_BLOCK_LEN])
|
||||||
|
assert weights.shape[0] == config.num_hidden_layers / _SLIDING_WINDOW_PATTERN
|
||||||
|
|
||||||
|
for i, matrix in enumerate(weights):
|
||||||
|
layer_idx = _SLIDING_WINDOW_PATTERN * i + attention_type_index
|
||||||
|
base_path = f"layers.{layer_idx}"
|
||||||
|
|
||||||
|
if "altup" in path:
|
||||||
|
altup_path = f"{base_path}.altup"
|
||||||
|
|
||||||
|
if param == "correct_output_scale":
|
||||||
|
converted_paths.append(f"{altup_path}.correct_output_scale")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif param == "correction_coefs":
|
||||||
|
converted_paths.append(f"{altup_path}.correction_coefs.weight")
|
||||||
|
converted_weights.append(matrix.transpose())
|
||||||
|
elif param == "prediction_coefs":
|
||||||
|
converted_paths.append(f"{altup_path}.prediction_coefs.weight")
|
||||||
|
converted_weights.append(
|
||||||
|
np.clip(
|
||||||
|
matrix.reshape(config.altup_num_inputs, config.altup_num_inputs**2).transpose(),
|
||||||
|
-config.altup_coef_clip,
|
||||||
|
config.altup_coef_clip,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if path.endswith("modality_router"):
|
||||||
|
converted_paths.append(f"{altup_path}.modality_router.weight")
|
||||||
|
converted_weights.append(matrix.transpose())
|
||||||
|
elif path.endswith("router_norm_layer"):
|
||||||
|
converted_paths.append(f"{altup_path}.router_norm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif path.endswith("attn/attn_vec_einsum"):
|
||||||
|
converted_paths.append(f"{base_path}.self_attn.o_proj.weight")
|
||||||
|
converted_weights.append(
|
||||||
|
matrix.transpose(2, 0, 1).reshape(config.hidden_size, config.num_attention_heads * config.head_dim)
|
||||||
|
)
|
||||||
|
elif path.endswith("attn/kv_einsum"):
|
||||||
|
converted_paths.extend(
|
||||||
|
[
|
||||||
|
f"{base_path}.self_attn.k_proj.weight",
|
||||||
|
f"{base_path}.self_attn.v_proj.weight",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
k_proj_weights, v_proj_weights = matrix.transpose(0, 2, 1, 3)
|
||||||
|
kv_proj_shape = (config.hidden_size, config.num_key_value_heads * config.head_dim)
|
||||||
|
converted_weights.extend(
|
||||||
|
[
|
||||||
|
k_proj_weights.reshape(kv_proj_shape).transpose(),
|
||||||
|
v_proj_weights.reshape(kv_proj_shape).transpose(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
elif path.endswith("attn/q_einsum"):
|
||||||
|
converted_paths.append(f"{base_path}.self_attn.q_proj.weight")
|
||||||
|
converted_weights.append(
|
||||||
|
matrix.transpose(1, 0, 2)
|
||||||
|
.reshape(config.hidden_size, config.num_attention_heads * config.head_dim)
|
||||||
|
.transpose()
|
||||||
|
)
|
||||||
|
elif path.endswith("attn/query_norm"):
|
||||||
|
converted_paths.append(f"{base_path}.self_attn.q_norm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif path.endswith("attn/key_norm"):
|
||||||
|
converted_paths.append(f"{base_path}.self_attn.k_norm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif path.endswith("laurel_block/linear_left"):
|
||||||
|
converted_paths.append(f"{base_path}.laurel.linear_left.weight")
|
||||||
|
converted_weights.append(matrix.transpose())
|
||||||
|
elif path.endswith("laurel_block/linear_right"):
|
||||||
|
converted_paths.append(f"{base_path}.laurel.linear_right.weight")
|
||||||
|
converted_weights.append(matrix.transpose())
|
||||||
|
elif path.endswith("mlp/gating_einsum"):
|
||||||
|
converted_paths.extend([f"{base_path}.mlp.gate_proj.weight", f"{base_path}.mlp.up_proj.weight"])
|
||||||
|
gate_proj_weight, up_proj_weight = matrix
|
||||||
|
converted_weights.extend([gate_proj_weight, up_proj_weight])
|
||||||
|
elif path.endswith("mlp/linear"):
|
||||||
|
converted_paths.append(f"{base_path}.mlp.down_proj.weight")
|
||||||
|
converted_weights.append(matrix.transpose())
|
||||||
|
elif path.endswith("per_layer_input_gate"):
|
||||||
|
converted_paths.append(f"{base_path}.per_layer_input_gate.weight")
|
||||||
|
converted_weights.append(matrix.transpose())
|
||||||
|
elif path.endswith("per_layer_projection"):
|
||||||
|
converted_paths.append(f"{base_path}.per_layer_projection.weight")
|
||||||
|
converted_weights.append(matrix.transpose())
|
||||||
|
elif path.endswith("post_attention_norm"):
|
||||||
|
converted_paths.append(f"{base_path}.post_attention_layernorm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif path.endswith("post_ffw_norm"):
|
||||||
|
converted_paths.append(f"{base_path}.post_feedforward_layernorm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif path.endswith("post_laurel_norm"):
|
||||||
|
converted_paths.append(f"{base_path}.laurel.post_laurel_norm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif path.endswith("post_per_layer_input_norm"):
|
||||||
|
converted_paths.append(f"{base_path}.post_per_layer_input_norm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif path.endswith("pre_attention_norm"):
|
||||||
|
converted_paths.append(f"{base_path}.input_layernorm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif path.endswith("pre_ffw_norm"):
|
||||||
|
converted_paths.append(f"{base_path}.pre_feedforward_layernorm.weight")
|
||||||
|
converted_weights.append(matrix)
|
||||||
|
elif path == _TRANSFORMER_EMBEDDER:
|
||||||
|
if param == "input_embedding":
|
||||||
|
converted_paths.append("embed_tokens.weight")
|
||||||
|
# Gemma 3n model doesn't have soft tokens or "end of" tokens for images and audio in its input and output
|
||||||
|
# embeddings, so we resize to avoid bugs observed with Mllama
|
||||||
|
pre_expansion_embeddings = weights
|
||||||
|
pad_token_slice = slice(config.pad_token_id, config.pad_token_id + 1)
|
||||||
|
new_embeddings = np.repeat(pre_expansion_embeddings[pad_token_slice], 256, axis=0)
|
||||||
|
weights = np.vstack([pre_expansion_embeddings, new_embeddings])
|
||||||
|
converted_weights.append(weights)
|
||||||
|
elif param == "per_layer_embeddings":
|
||||||
|
converted_paths.append("embed_tokens_per_layer.weight")
|
||||||
|
converted_weights.append(
|
||||||
|
weights.reshape(
|
||||||
|
config.vocab_size_per_layer_input, config.num_hidden_layers * config.hidden_size_per_layer_input
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif path.startswith(_TRANSFORMER_EMBEDDER):
|
||||||
|
# TODO: ryanmullins - support multimodal norms and projections
|
||||||
|
if path.endswith("per_layer_model_projection"):
|
||||||
|
converted_paths.append("per_layer_model_projection.weight")
|
||||||
|
converted_weights.append(
|
||||||
|
weights.reshape(
|
||||||
|
config.hidden_size, config.num_hidden_layers * config.hidden_size_per_layer_input
|
||||||
|
).transpose()
|
||||||
|
)
|
||||||
|
elif path.endswith("per_layer_projection_norm"):
|
||||||
|
converted_paths.append("per_layer_projection_norm.weight")
|
||||||
|
converted_weights.append(weights)
|
||||||
|
elif path == _TRANSFORMER_FINAL_NORM:
|
||||||
|
converted_paths = ["norm.weight"]
|
||||||
|
converted_weights = [weights]
|
||||||
|
|
||||||
|
if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
|
||||||
|
raise ValueError(
|
||||||
|
"The `converted_paths` and `converted_weights` should be the same "
|
||||||
|
f"length. Got {cpl} and {cwl}, respectively, for {path}."
|
||||||
|
)
|
||||||
|
|
||||||
|
return zip(converted_paths, converted_weights)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_vision_weights(
|
||||||
|
config: Gemma3nVisionConfig,
|
||||||
|
path: str,
|
||||||
|
param: str,
|
||||||
|
weights: np.ndarray,
|
||||||
|
) -> Iterable[tuple[str, np.ndarray]]:
|
||||||
|
def generate_base_path(path: str, block_type: str) -> tuple[str, tuple[int, int]]:
|
||||||
|
re_str = r"{}(\d+)/".format(block_type)
|
||||||
|
re_pattern = re.compile(re_str)
|
||||||
|
match = re.search(re_pattern, path).group(1)
|
||||||
|
idx = abs(int(match)) - 1
|
||||||
|
|
||||||
|
for block_idx, v in enumerate(_MOBILE_NET_TIMM_SUMMED_BLOCK_SIZES):
|
||||||
|
if v > idx:
|
||||||
|
offset = _MOBILE_NET_TIMM_SUMMED_BLOCK_SIZES[block_idx - 1] if block_idx > 0 else 0
|
||||||
|
layer_idx = idx - offset
|
||||||
|
return f"blocks.{block_idx}.{layer_idx}", (block_idx, layer_idx)
|
||||||
|
|
||||||
|
raise ValueError(f"could not extract a base path from {path}")
|
||||||
|
|
||||||
|
if _MOBILE_NET_MSFA in path:
|
||||||
|
converted_path = "msfa"
|
||||||
|
|
||||||
|
if "ffn/Normalize_0" in path:
|
||||||
|
converted_path += ".ffn.pw_exp.bn.weight"
|
||||||
|
converted_weight = weights
|
||||||
|
elif "ffn/Normalize_1" in path:
|
||||||
|
converted_path += ".ffn.pw_proj.bn.weight"
|
||||||
|
converted_weight = weights
|
||||||
|
elif "ffn/expand" in path:
|
||||||
|
converted_path += ".ffn.pw_exp.conv.weight"
|
||||||
|
converted_weight = weights.transpose()[:, :, None, None]
|
||||||
|
elif "ffn/project" in path:
|
||||||
|
converted_path += ".ffn.pw_proj.conv.weight"
|
||||||
|
converted_weight = weights.transpose()[:, :, None, None]
|
||||||
|
elif "Normalize_0" in path:
|
||||||
|
converted_path += ".norm.weight"
|
||||||
|
converted_weight = weights
|
||||||
|
elif _MOBILE_NET_CONV in path:
|
||||||
|
if "Conv_0" in path:
|
||||||
|
converted_path = "conv_stem.conv.weight"
|
||||||
|
converted_weight = weights.transpose(3, 2, 1, 0)
|
||||||
|
elif "Normalize_0" in path:
|
||||||
|
converted_path = "conv_stem.bn.weight"
|
||||||
|
converted_weight = weights
|
||||||
|
elif _MOBILE_NET_FIB in path:
|
||||||
|
converted_path, _ = generate_base_path(path, _MOBILE_NET_FIB)
|
||||||
|
if "Normalize_0" in path:
|
||||||
|
converted_path += ".bn1.weight"
|
||||||
|
converted_weight = weights
|
||||||
|
elif "Normalize_1" in path:
|
||||||
|
converted_path += ".bn2.weight"
|
||||||
|
converted_weight = weights
|
||||||
|
elif "expand_conv" in path:
|
||||||
|
converted_path += ".conv_exp.weight"
|
||||||
|
converted_weight = weights.transpose(3, 2, 1, 0)
|
||||||
|
else:
|
||||||
|
converted_path += ".conv_pwl.weight"
|
||||||
|
converted_weight = weights.transpose()[:, :, None, None]
|
||||||
|
elif _MOBILE_NET_MQA in path:
|
||||||
|
converted_path, _ = generate_base_path(path, _MOBILE_NET_MQA)
|
||||||
|
|
||||||
|
if "LayerScale_0" in path:
|
||||||
|
converted_path += ".layer_scale.gamma"
|
||||||
|
converted_weight = weights
|
||||||
|
elif "Normalize_0" in path:
|
||||||
|
converted_path += ".norm.weight"
|
||||||
|
converted_weight = weights
|
||||||
|
elif "Normalize_1" in path:
|
||||||
|
converted_path += ".attn.key.norm.weight"
|
||||||
|
converted_weight = weights
|
||||||
|
elif "Normalize_2" in path:
|
||||||
|
converted_path += ".attn.value.norm.weight"
|
||||||
|
converted_weight = weights
|
||||||
|
elif "key_dwconv" in path:
|
||||||
|
converted_path += ".attn.key.down_conv.weight"
|
||||||
|
converted_weight = weights.transpose()
|
||||||
|
elif "key_proj" in path:
|
||||||
|
converted_path += ".attn.key.proj.weight"
|
||||||
|
converted_weight = weights.transpose()[:, :, None, None]
|
||||||
|
elif "output_proj" in path:
|
||||||
|
converted_path += ".attn.output.proj.weight"
|
||||||
|
converted_weight = weights.transpose()[:, :, None, None]
|
||||||
|
elif "query_proj" in path:
|
||||||
|
converted_path += ".attn.query.proj.weight"
|
||||||
|
converted_weight = weights.transpose()[:, :, None, None]
|
||||||
|
elif "value_dwconv" in path:
|
||||||
|
converted_path += ".attn.value.down_conv.weight"
|
||||||
|
converted_weight = weights.transpose()
|
||||||
|
elif "value_proj" in path:
|
||||||
|
converted_path += ".attn.value.proj.weight"
|
||||||
|
converted_weight = weights.transpose()[:, :, None, None]
|
||||||
|
elif _MOBILE_NET_UIB in path:
|
||||||
|
converted_path, idx_key = generate_base_path(path, _MOBILE_NET_UIB)
|
||||||
|
|
||||||
|
has_dw_start = idx_key in _MOBILE_NET_UIB_HAS_DW_START
|
||||||
|
has_dw_mid = idx_key in _MOBILE_NET_UIB_HAS_DW_MID
|
||||||
|
|
||||||
|
if "LayerScale_0" in path:
|
||||||
|
converted_path += ".layer_scale.gamma"
|
||||||
|
converted_weight = weights
|
||||||
|
elif "Normalize_0" in path:
|
||||||
|
converted_path += ".dw_start.bn.weight" if has_dw_start else ".pw_exp.bn.weight"
|
||||||
|
converted_weight = weights
|
||||||
|
elif "Normalize_1" in path:
|
||||||
|
converted_path += ".pw_exp.bn.weight" if has_dw_start else ".pw_proj.bn.weight"
|
||||||
|
converted_weight = weights
|
||||||
|
elif "Normalize_2" in path:
|
||||||
|
converted_path += ".dw_mid.bn.weight" if has_dw_mid else ".pw_proj.bn.weight"
|
||||||
|
converted_weight = weights
|
||||||
|
elif "Normalize_3" in path:
|
||||||
|
converted_path += ".pw_proj.bn.weight"
|
||||||
|
converted_weight = weights
|
||||||
|
elif "expand" in path:
|
||||||
|
converted_path += ".pw_exp.conv.weight"
|
||||||
|
converted_weight = weights.transpose()[:, :, None, None]
|
||||||
|
elif "middle_dwconv" in path:
|
||||||
|
converted_path += ".dw_mid.conv.weight"
|
||||||
|
converted_weight = weights.transpose(3, 2, 1, 0)
|
||||||
|
elif "project" in path:
|
||||||
|
converted_path += ".pw_proj.conv.weight"
|
||||||
|
converted_weight = weights.transpose()[:, :, None, None]
|
||||||
|
elif "start_dwconv" in path:
|
||||||
|
converted_path += ".dw_start.conv.weight"
|
||||||
|
converted_weight = weights.transpose(3, 2, 1, 0)
|
||||||
|
|
||||||
|
return [(converted_path, converted_weight)]
|
||||||
|
|
||||||
|
|
||||||
|
def convert(checkpoint_path: str, config: Gemma3nConfig) -> dict[str, torch.Tensor]:
"""Loads Orbax checkpoint from `checkpoint_path` and converts it to HF tree."""
|
||||||
|
checkpointer = obc.PyTreeCheckpointer()
|
||||||
|
ckpt = checkpointer.restore(checkpoint_path)
|
||||||
|
hf_tree: dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
|
def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None:
|
||||||
|
hf_tree[path] = torch.from_numpy(weights.astype("float32")).type(target_dtype)
|
||||||
|
if _VERBOSE.value:
|
||||||
|
logging.info(
|
||||||
|
"%s converted shape=%s with dtype=%s",
|
||||||
|
path,
|
||||||
|
weights.shape,
|
||||||
|
target_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
for (path, param), value in tree.flatten_with_path(ckpt):
|
||||||
|
if param == "audio_input_embedding_extra":
|
||||||
|
update_tree("model.embed_audio.embedding.weight", value, config.audio_config.torch_dtype)
|
||||||
|
elif path.endswith("audio_embedding_norm"):
|
||||||
|
update_tree("model.embed_audio.hard_embedding_norm.weight", value, config.audio_config.torch_dtype)
|
||||||
|
elif path.endswith("audio_input_projection"):
|
||||||
|
update_tree(
|
||||||
|
"model.embed_audio.embedding_projection.weight", value.transpose(), config.audio_config.torch_dtype
|
||||||
|
)
|
||||||
|
elif path.endswith("audio_soft_embedding_norm"):
|
||||||
|
update_tree("model.embed_audio.soft_embedding_norm.weight", value, config.audio_config.torch_dtype)
|
||||||
|
elif param == "mm_input_embedding_extra":
|
||||||
|
update_tree("model.embed_vision.embedding.weight", value, config.vision_config.torch_dtype)
|
||||||
|
elif path.endswith("mm_hard_embedding_norm"):
|
||||||
|
update_tree("model.embed_vision.hard_embedding_norm.weight", value, config.vision_config.torch_dtype)
|
||||||
|
elif path.endswith("mm_input_projection"):
|
||||||
|
update_tree(
|
||||||
|
"model.embed_vision.embedding_projection.weight", value.transpose(), config.vision_config.torch_dtype
|
||||||
|
)
|
||||||
|
elif path.endswith("mm_soft_embedding_norm"):
|
||||||
|
update_tree("model.embed_vision.soft_embedding_norm.weight", value, config.vision_config.torch_dtype)
|
||||||
|
elif path.startswith(_TRANSFORMER_PARAMETER):
|
||||||
|
for path, weights in convert_transformer_weights(config.text_config, path, param, value):
|
||||||
|
update_tree(f"model.language_model.{path}", weights, config.text_config.torch_dtype)
|
||||||
|
elif _MOBILE_NET_PREFIX in path:
|
||||||
|
mobilenet_prefix_idx = path.index(_MOBILE_NET_PREFIX)
|
||||||
|
path = path[mobilenet_prefix_idx:]
|
||||||
|
for path, weights in convert_vision_weights(config.vision_config, path, param, value):
|
||||||
|
update_tree(f"model.vision_tower.timm_model.{path}", weights, config.vision_config.torch_dtype)
|
||||||
|
elif path.startswith(_AUDIO_ENCODER_PARAMETER):
|
||||||
|
for path, weights in convert_audio_encoder_weights(config.audio_config, path, param, value):
|
||||||
|
update_tree(f"model.audio_tower.{path}", weights, config.audio_config.torch_dtype)
|
||||||
|
|
||||||
|
hf_tree["lm_head.weight"] = hf_tree["model.language_model.embed_tokens.weight"]
|
||||||
|
|
||||||
|
return hf_tree
|
||||||
|
|
||||||
|
|
||||||
|
def main(*args):
|
||||||
|
del args
|
||||||
|
|
||||||
|
output_path = _OUTPUT_PATH.value
|
||||||
|
variant = _VARIANT.value
|
||||||
|
|
||||||
|
config = _VARIANTS[variant]
|
||||||
|
config.audio_config.torch_dtype = getattr(torch, _AUDIO_DTYPE.value)
|
||||||
|
config.text_config.torch_dtype = getattr(torch, _TRANSFORMER_DTYPE.value)
|
||||||
|
config.vision_config.torch_dtype = getattr(torch, _VISION_DTYPE.value)
|
||||||
|
if _INCLUDE_CHAT_TEMPLATE.value:
|
||||||
|
# Chat template is included for instruction tuned models, which treat
|
||||||
|
# both "<eos>" and "<end_of_turn>" as generation stoppers.
|
||||||
|
config.eos_token_id = [1, 106]
|
||||||
|
|
||||||
|
logging.info(
"Converting Gemma 3n (%s) @ %s (language) and %s (vision)",
|
||||||
|
variant,
|
||||||
|
_TRANSFORMER_DTYPE.value,
|
||||||
|
_VISION_DTYPE.value,
|
||||||
|
)
|
||||||
|
state_tree = convert(_CHECKPOINT_PATH.value, config)
|
||||||
|
logging.info("Converted Gemma 3n (%s) state tree from Orbax to Hugging Face.", variant)
|
||||||
|
|
||||||
|
with accelerate.init_empty_weights():
|
||||||
|
model = Gemma3nForConditionalGeneration(config=config)
|
||||||
|
|
||||||
|
model.load_state_dict(state_tree, assign=True, strict=True)
|
||||||
|
logging.info(
"Loaded Gemma 3n (%s) in Hugging Face Transformers as a %s instance.",
|
||||||
|
variant,
|
||||||
|
type(model).__name__,
|
||||||
|
)
|
||||||
|
model.save_pretrained(output_path, state_dict=state_tree, safe_serialization=True)
|
||||||
|
logging.info(
"Saved Gemma 3n (%s) to SafeTensors in %s using %s",
|
||||||
|
variant,
|
||||||
|
output_path,
|
||||||
|
type(model).__name__,
|
||||||
|
)
|
||||||
|
del model
|
||||||
|
del state_tree
|
||||||
|
|
||||||
|
chat_template_kwargs = {"chat_template": _CHAT_TEMPLATE} if _INCLUDE_CHAT_TEMPLATE.value else {}
|
||||||
|
|
||||||
|
tokenizer = GemmaTokenizerFast(
|
||||||
|
_TOKENIZER_PATH.value,
|
||||||
|
add_bos_token=True,
|
||||||
|
extra_special_tokens={
|
||||||
|
"image_token": "<image_soft_token>", # Should be ID=262_145
|
||||||
|
"boi_token": "<start_of_image>", # Should be ID=255_999
|
||||||
|
"eoi_token": "<end_of_image>", # Should be ID=262_144
|
||||||
|
"audio_token": "<audio_soft_token>", # Should be ID=262_273
|
||||||
|
"boa_token": "<start_of_audio>", # Should be ID=256_000
|
||||||
|
"eoa_token": "<end_of_audio>", # Should be ID=262_272
|
||||||
|
},
|
||||||
|
**chat_template_kwargs,
|
||||||
|
)
|
||||||
|
tokenizer.save_pretrained(output_path)
|
||||||
|
logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path)
|
||||||
|
|
||||||
|
feature_extractor = Gemma3nAudioFeatureExtractor()
|
||||||
|
image_processor = SiglipImageProcessorFast(
|
||||||
|
image_seq_length=256,
|
||||||
|
image_mean=(0.5,) * 3,
|
||||||
|
image_std=(0.5,) * 3,
|
||||||
|
size={"height": 768, "width": 768},
|
||||||
|
resample=PILImageResampling.BILINEAR,
|
||||||
|
do_normalize=False,
|
||||||
|
)
|
||||||
|
processor = Gemma3nProcessor(
|
||||||
|
feature_extractor=feature_extractor,
|
||||||
|
image_processor=image_processor,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
**chat_template_kwargs,
|
||||||
|
)
|
||||||
|
processor.save_pretrained(output_path)
|
||||||
|
|
||||||
|
logging.info("Saved Gemma3nProcessor for %s to %s", variant, output_path)
|
||||||
|
|
||||||
|
# NOTE: feature_extractor and image_processor both serialize to the same file, preprocessor_config.json, so one
# overwrites the other during processor.save_pretrained(). The two configs can be merged and loaded from a single
# preprocessor_config.json, so we write the merged file explicitly here.
|
||||||
|
feature_extractor_config = json.loads(feature_extractor.to_json_string())
|
||||||
|
image_processor_config = json.loads(image_processor.to_json_string())
|
||||||
|
preprocessor_config = {**feature_extractor_config, **image_processor_config}
|
||||||
|
with open(os.path.join(output_path, "preprocessor_config.json"), "w", encoding="utf-8") as writer:
|
||||||
|
writer.write(json.dumps(preprocessor_config, indent=2, sort_keys=True) + "\n")
|
||||||
|
|
||||||
|
logging.info("Saved joint preprocessor_config.json for %s to %s", variant, output_path)
|
||||||
|
|
||||||
|
del feature_extractor, image_processor, processor, tokenizer
|
||||||
|
|
||||||
|
generation_config = GenerationConfig(
|
||||||
|
pad_token_id=config.text_config.pad_token_id,
|
||||||
|
bos_token_id=config.text_config.bos_token_id,
|
||||||
|
eos_token_id=(
|
||||||
|
[config.text_config.eos_token_id, 106] if _INCLUDE_CHAT_TEMPLATE.value else config.text_config.eos_token_id
|
||||||
|
),
|
||||||
|
cache_implementation="hybrid",
|
||||||
|
temperature=1.0,
|
||||||
|
do_sample=True,
|
||||||
|
top_k=64,
|
||||||
|
top_p=0.95,
|
||||||
|
)
|
||||||
|
generation_config.save_pretrained(output_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app.run(main)
|
||||||
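The `input_embedding` branch of `convert_transformer_weights` above grows the embedding table by repeating the pad-token row and stacking the copies under the original matrix. The following is a minimal sketch of that step on a toy array. The helper name `expand_embeddings` and the toy sizes are hypothetical and exist only for illustration; the conversion script itself inlines the logic with a fixed count of 256 new rows.

```python
import numpy as np


def expand_embeddings(embeddings: np.ndarray, pad_token_id: int, num_new_rows: int) -> np.ndarray:
    """Append `num_new_rows` copies of the pad-token row to `embeddings`.

    Hypothetical helper: mirrors the np.repeat / np.vstack pattern used in the
    conversion script, but is not part of it.
    """
    # Slicing (instead of integer indexing) keeps the row 2D: (1, hidden_size).
    pad_row = embeddings[pad_token_id : pad_token_id + 1]
    # Repeat the single row along axis 0 to get (num_new_rows, hidden_size).
    new_rows = np.repeat(pad_row, num_new_rows, axis=0)
    # Stack the new rows below the original table.
    return np.vstack([embeddings, new_rows])


# Toy table: 4 tokens, hidden size 3 (assumed values, far smaller than a real model).
toy = np.arange(12, dtype=np.float32).reshape(4, 3)
expanded = expand_embeddings(toy, pad_token_id=0, num_new_rows=2)
print(expanded.shape)
```

Initializing the extra rows from the pad-token embedding, rather than from zeros or random values, gives the added token IDs a row that the checkpoint already contains, and the original rows keep their indices because the new rows are only appended.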
338
src/transformers/models/gemma3n/feature_extraction_gemma3n.py
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||||
|
from ...feature_extraction_utils import BatchFeature
|
||||||
|
from ...utils import PaddingStrategy, TensorType, logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_fb_matrix(
    n_freqs: int,
    f_min: float,
    f_max: float,
    n_mels: int,
    sample_rate: int,
    fft_length: int,
    norm: Optional[str] = None,
) -> np.ndarray:
    r"""Create a frequency bin conversion matrix (NumPy version).

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_mels (int): Number of mel filterbanks
        sample_rate (int): Sample rate of the audio waveform
        fft_length (int): FFT length
        norm (Optional[str]): If 'slaney', divide the triangular mel weights by
          the width of the mel band (area normalization). (Default: ``None``)

    Returns:
        np.ndarray: Triangular filter banks (fb matrix) of size (``n_freqs``,
        ``n_mels``), i.e. the number of frequency bins by the number of
        filterbanks.
        Each column is a filterbank, so for a matrix A of size
        (..., ``n_freqs``) the applied result is
        ``A @ create_fb_matrix(A.shape[-1], ...)``.
    """

    if norm is not None and norm != "slaney":
        raise ValueError("norm must be one of None or 'slaney'")

    # freq bins
    all_freqs = np.arange(n_freqs, dtype=np.float32) * (sample_rate / fft_length)

    # calculate mel freq bins
    # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
    m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
    m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
    m_pts = np.linspace(m_min, m_max, n_mels + 2)
    # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
    f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
    # calculate difference between each mel point and each stft freq point in Hz
    f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
    slopes = np.expand_dims(f_pts, 0) - np.expand_dims(all_freqs, 1)  # (n_freqs, n_mels + 2)
    # create overlapping triangles
    zero = np.zeros(1, dtype=np.float32)
    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
    fb = np.maximum(zero, np.minimum(down_slopes, up_slopes))

    if norm == "slaney":
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
        fb *= np.expand_dims(enorm, 0)

    return fb
|
||||||
|
|
||||||
|
|
||||||
|
def _unfold(array: np.ndarray, dimension: int, size: int, step: int) -> np.ndarray:
    """A basic NumPy equivalent of PyTorch's unfold for 2D arrays along the last dim."""
    if array.ndim != 2:
        raise ValueError("This unfold implementation currently supports 2D arrays (batch, time).")
    if dimension != -1 and dimension != array.ndim - 1:
        raise ValueError("This unfold implementation only supports unfolding the last dimension.")

    batch_size, original_length = array.shape
    num_frames = (original_length - size) // step + 1

    if num_frames <= 0:
        return np.zeros((batch_size, 0, size), dtype=array.dtype)

    output_shape = (batch_size, num_frames, size)
    output_strides = (array.strides[0], array.strides[1] * step, array.strides[1])

    return np.lib.stride_tricks.as_strided(array, shape=output_shape, strides=output_strides)
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3nAudioFeatureExtractor(SequenceFeatureExtractor):
"""An audio feature extractor for Universal Speech Models https://arxiv.org/abs/2303.01037.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_size (`int`, *optional*, defaults to 128):
|
||||||
|
The feature dimension of the extracted features.
|
||||||
|
sampling_rate (`int`, *optional*, defaults to 16000):
|
||||||
|
The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
|
||||||
|
padding_value (`float`, *optional*, defaults to 0.0):
|
||||||
|
Padding value used to pad the audio. Should correspond to silences.
|
||||||
|
return_attention_mask (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to return the attention mask for the generated MEL spectrograms.
|
||||||
|
frame_length_ms (`float`, *optional*, defaults to 32.0):
|
||||||
|
The length of a frame in milliseconds.
|
||||||
|
hop_length_ms (`float`, *optional*, defaults to 10.0):
|
||||||
|
The hop (stride) in milliseconds between successive overlapping windows of the STFT used to obtain the Mel Frequency coefficients.
|
||||||
|
min_frequency (`float`, *optional*, defaults to 125.0):
|
||||||
|
The minimum frequency (in Hz) for the Mel filterbank.
|
||||||
|
max_frequency (`float`, *optional*, defaults to 7600.0):
|
||||||
|
The maximum frequency (in Hz) for the Mel filterbank.
|
||||||
|
preemphasis (`float`, *optional*, defaults to 0.97):
|
||||||
|
The preemphasis coefficient.
|
||||||
|
preemphasis_htk_flavor (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use HTK-style preemphasis.
|
||||||
|
fft_overdrive (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use FFT overdrive.
|
||||||
|
dither (`float`, *optional*, defaults to 0.0):
|
||||||
|
Adds dithering. In other words, adds a small Gaussian noise to each frame.
|
||||||
|
E.g. use 0.0001 to add dithering with a normal distribution centered
|
||||||
|
around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech).
|
||||||
|
The value 0.0 means no dithering.
|
||||||
|
Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces
|
||||||
|
the high log_mel_fbank values for signals with hard-zero sections,
|
||||||
|
when VAD cutoff is present in the signal.
|
||||||
|
input_scale_factor (`float`, *optional*, defaults to 1.0):
|
||||||
|
Scaling factor applied to the input waveform.
|
||||||
|
mel_floor (`float`, *optional*, defaults to 1e-05):
|
||||||
|
Minimum value for Mel spectrograms to avoid log(0).
|
||||||
|
per_bin_mean (`Optional[Sequence[float]]`, *optional*):
|
||||||
|
Mean values for per-bin normalization.
|
||||||
|
per_bin_stddev (`Optional[Sequence[float]]`, *optional*):
|
||||||
|
Standard deviation values for per-bin normalization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_input_names = ["input_features", "input_features_mask"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
feature_size: int = 128,
|
||||||
|
sampling_rate: int = 16_000,
|
||||||
|
padding_value: float = 0.0,
|
||||||
|
return_attention_mask: bool = True,
|
||||||
|
frame_length_ms: float = 32.0,
|
||||||
|
hop_length_ms: float = 10.0,
|
||||||
|
min_frequency: float = 125.0,
|
||||||
|
max_frequency: float = 7600.0,
|
||||||
|
preemphasis: float = 0.97,
|
||||||
|
preemphasis_htk_flavor: bool = True,
|
||||||
|
fft_overdrive: bool = True,
|
||||||
|
dither: float = 0.0,
|
||||||
|
input_scale_factor: float = 1.0,
|
||||||
|
mel_floor: float = 1e-5,
|
||||||
|
per_bin_mean: Optional[Sequence[float]] = None,
|
||||||
|
per_bin_stddev: Optional[Sequence[float]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
feature_size=feature_size,
|
||||||
|
sampling_rate=sampling_rate,
|
||||||
|
padding_value=padding_value,
|
||||||
|
return_attention_mask=return_attention_mask,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.min_frequency = min_frequency
|
||||||
|
self.max_frequency = max_frequency
|
||||||
|
self.preemphasis = preemphasis
|
||||||
|
self.preemphasis_htk_flavor = preemphasis_htk_flavor
|
||||||
|
self.fft_overdrive = fft_overdrive
|
||||||
|
self.dither = dither
|
||||||
|
self.input_scale_factor = input_scale_factor
|
||||||
|
self.frame_length = int(round(sampling_rate * frame_length_ms / 1000.0))
|
||||||
|
self.hop_length = int(round(sampling_rate * hop_length_ms / 1000.0))
|
||||||
|
self.mel_floor = np.array(mel_floor, dtype=np.float64)
|
||||||
|
|
||||||
|
fft_length = 2 ** math.ceil(math.log2(self.frame_length))
|
||||||
|
if self.fft_overdrive:
|
||||||
|
fft_length *= 2
|
||||||
|
self.fft_length = fft_length
|
||||||
|
|
||||||
|
hann_arange = np.arange(self.frame_length, dtype=np.float32)
|
||||||
|
window = 0.5 * (1 - np.cos(2 * np.pi * hann_arange / self.frame_length))
|
||||||
|
self.window = window.astype(np.float32)
|
||||||
|
|
||||||
|
self.mel_filters = create_fb_matrix(
|
||||||
|
n_freqs=self.fft_length // 2 + 1,
|
||||||
|
f_min=min_frequency,
|
||||||
|
f_max=max_frequency,
|
||||||
|
n_mels=feature_size,
|
||||||
|
sample_rate=self.sampling_rate,
|
||||||
|
norm=None,
|
||||||
|
fft_length=fft_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
if per_bin_mean is not None:
|
||||||
|
self.per_bin_mean = np.array(per_bin_mean).reshape(1, 1, feature_size)
|
||||||
|
else:
|
||||||
|
self.per_bin_mean = None
|
||||||
|
|
||||||
|
if per_bin_stddev is not None:
|
||||||
|
self.per_bin_stddev = np.array(per_bin_stddev).reshape(1, 1, feature_size)
|
||||||
|
else:
|
||||||
|
self.per_bin_stddev = None
|
||||||
|
|
||||||
|
def _extract_spectrogram(self, waveform: np.ndarray, attention_mask: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
""""""
|
||||||
|
if waveform.ndim == 1: # If single waveform, add batch dimension
|
||||||
|
waveform = np.expand_dims(waveform, axis=0)
|
||||||
|
|
||||||
|
if self.dither > 0.0:
|
||||||
|
waveform = waveform + self.dither * np.random.randn(*waveform.shape).astype(waveform.dtype)
|
||||||
|
|
||||||
|
if self.input_scale_factor != 1.0:
|
||||||
|
waveform = waveform * self.input_scale_factor
|
||||||
|
|
||||||
|
frame_size_for_unfold = self.frame_length + 1
|
||||||
|
|
||||||
|
# NumPy equivalent of unfold for [B, NumFrames, frame_size_for_unfold]
|
||||||
|
frames_to_process = _unfold(waveform, dimension=-1, size=frame_size_for_unfold, step=self.hop_length)
|
||||||
|
|
||||||
|
if self.preemphasis > 0.0:
|
||||||
|
if self.preemphasis_htk_flavor:
|
||||||
|
first_in_frame = frames_to_process[..., :1] * (1.0 - self.preemphasis)
|
||||||
|
rest_in_frame = frames_to_process[..., 1:-1] - self.preemphasis * frames_to_process[..., :-2]
|
||||||
|
frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1)
|
||||||
|
else:
|
||||||
|
frames = frames_to_process[..., 1:] - self.preemphasis * frames_to_process[..., :-1]
|
||||||
|
else:
|
||||||
|
frames = frames_to_process[..., :-1]
|
||||||
|
|
||||||
|
frames = frames * self.window # Broadcasting window
|
||||||
|
stft = np.fft.rfft(frames, n=self.fft_length, axis=-1)
|
||||||
|
|
||||||
|
magnitude_spec = np.abs(stft)
|
||||||
|
|
||||||
|
mel_spec = np.matmul(magnitude_spec, self.mel_filters)
|
||||||
|
log_mel_spec = np.log(np.maximum(mel_spec, self.mel_floor))
|
||||||
|
|
||||||
|
if self.per_bin_mean is not None:
|
||||||
|
log_mel_spec = log_mel_spec - self.per_bin_mean # Broadcasting
|
||||||
|
|
||||||
|
if self.per_bin_stddev is not None:
|
||||||
|
log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting
|
||||||
|
|
||||||
|
mel_spectrogram = log_mel_spec.squeeze()
|
||||||
|
mask = attention_mask[:: self.hop_length].astype(bool)
|
||||||
|
# TODO: The filtered mask is always exactly 3 elements longer than the mel_spectrogram. Why???
|
||||||
|
return mel_spectrogram, mask[: mel_spectrogram.shape[0]]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
|
||||||
|
padding: Union[bool, str, PaddingStrategy] = "longest",
|
||||||
|
max_length: Optional[int] = 480_000,
|
||||||
|
truncation: bool = True,
|
||||||
|
pad_to_multiple_of: Optional[int] = 128,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
return_attention_mask: Optional[bool] = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> BatchFeature:
|
||||||
|
"""Creates a batch of MEL spectrograms from the provided raw speech.
|
||||||
|
|
||||||
|
This implementation uses a different algorithm for windowing and preemphasis compared to the built-in
|
||||||
|
`transformers.audio_utils.spectrogram()` function that _will_ result in different outputs. Consider this
|
||||||
|
carefully when selecting an audio feature extractor, especially with pre-trained models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_speech:
|
||||||
|
The audio for which MEL spectrograms are created.
|
||||||
|
padding (`Union[bool, str, PaddingStrategy]`, *optional*, defaults to `"longest"`):
|
||||||
|
The padding strategy to use for batches of audio with different lengths.
|
||||||
|
max_length (`int`, *optional*, defaults to 480000):
|
||||||
|
If provided, defines the maximum length of the audio to allow. Audio longer than this will be
|
||||||
|
truncated if `truncation=True`.
|
||||||
|
truncation (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to truncate audio above `max_length`.
|
||||||
|
pad_to_multiple_of (`int`, *optional*, defaults to 128):
|
||||||
|
When padding, pad to a multiple of this value. The default value is defined for optimal TPU support.
|
||||||
|
return_tensors (`Union[str, TensorType]`, *optional*, defaults to `None`):
|
||||||
|
The type of tensors to return (e.g., NumPy, Torch, JAX, TensorFlow).
|
||||||
|
return_attention_mask (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to return the attention mask for the generated MEL spectrograms.
|
||||||
|
"""
|
||||||
|
|
||||||
|
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
|
||||||
|
is_batched_sequence = isinstance(raw_speech, Sequence) and isinstance(raw_speech[0], (np.ndarray, Sequence))
|
||||||
|
is_batched = is_batched_numpy or is_batched_sequence
|
||||||
|
|
||||||
|
if is_batched:
|
||||||
|
raw_speech = [np.asarray([rs]).T for rs in raw_speech]
|
||||||
|
elif not is_batched and not isinstance(raw_speech, np.ndarray):
|
||||||
|
raw_speech = np.asarray(raw_speech)
|
||||||
|
|
||||||
|
if not is_batched: # always return a batch
|
||||||
|
raw_speech = [np.asarray([raw_speech])]
|
||||||
|
|
||||||
|
batched_speech = self.pad(
|
||||||
|
BatchFeature({"input_features": raw_speech}),
|
||||||
|
padding=padding,
|
||||||
|
max_length=max_length,
|
||||||
|
truncation=truncation,
|
||||||
|
pad_to_multiple_of=pad_to_multiple_of,
|
||||||
|
return_attention_mask=return_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
prepared_speech = []
|
||||||
|
prepared_speech_mask = []
|
||||||
|
for speech, mask in zip(batched_speech.input_features, batched_speech.attention_mask):
|
||||||
|
speech, mask = self._extract_spectrogram(speech.T, mask)
|
||||||
|
prepared_speech.append(speech.astype(np.float32))
|
||||||
|
prepared_speech_mask.append(mask)
|
||||||
|
|
||||||
|
return BatchFeature(
|
||||||
|
{"input_features": prepared_speech, "input_features_mask": prepared_speech_mask},
|
||||||
|
tensor_type=return_tensors,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Gemma3nAudioFeatureExtractor"]
|
||||||
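The two preemphasis branches in `_extract_spectrogram` above can be illustrated on a single frame. The following is a minimal sketch in plain Python, not the library code: `htk_preemphasis` and `plain_preemphasis` are hypothetical helper names, lists stand in for the NumPy arrays used in the diff, and each frame is assumed to hold `frame_length + 1` samples, as the unfold step suggests.

```python
def htk_preemphasis(frame, coeff=0.97):
    """HTK-style branch: scale the first sample, difference the rest.

    `frame` holds frame_length + 1 samples; the result holds frame_length.
    The last input sample is not used, mirroring the [..., 1:-1] and
    [..., :-2] slices in the diff.
    """
    first = frame[0] * (1.0 - coeff)
    rest = [frame[i] - coeff * frame[i - 1] for i in range(1, len(frame) - 1)]
    return [first] + rest


def plain_preemphasis(frame, coeff=0.97):
    """Non-HTK branch: y[i] = x[i + 1] - coeff * x[i] over the whole frame."""
    return [frame[i] - coeff * frame[i - 1] for i in range(1, len(frame))]


# A toy frame of frame_length + 1 = 5 samples (hypothetical values).
toy_frame = [2.0, 1.0, 2.0, 3.0, 4.0]
htk_out = htk_preemphasis(toy_frame, coeff=0.5)
plain_out = plain_preemphasis(toy_frame, coeff=0.5)
```

Both branches return `frame_length` samples, which is why the extra sample is requested from the unfold step: it lets the filter look one sample back (or drop one sample) without shrinking the frame below the window length.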
2422
src/transformers/models/gemma3n/modeling_gemma3n.py
Normal file
File diff suppressed because it is too large
2664
src/transformers/models/gemma3n/modular_gemma3n.py
Normal file
File diff suppressed because it is too large
191
src/transformers/models/gemma3n/processing_gemma3n.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ...feature_extraction_utils import BatchFeature
|
||||||
|
from ...image_utils import ImageInput, make_nested_list_of_images
|
||||||
|
from ...processing_utils import AudioKwargs, ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
|
||||||
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3nImagesKwargs(ImagesKwargs):
|
||||||
|
do_pan_and_scan: Optional[bool]
|
||||||
|
pan_and_scan_min_crop_size: Optional[int]
|
||||||
|
pan_and_scan_max_num_crops: Optional[int]
|
||||||
|
pan_and_scan_min_ratio_to_activate: Optional[float]
|
||||||
|
do_convert_rgb: Optional[bool]
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3nProcessorKwargs(ProcessingKwargs, total=False):
|
||||||
|
audio_kwargs: AudioKwargs
|
||||||
|
images_kwargs: Gemma3nImagesKwargs
|
||||||
|
_defaults = {
|
||||||
|
"text_kwargs": {
|
||||||
|
"padding": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3nProcessor(ProcessorMixin):
|
||||||
|
"""
|
||||||
|
A processor for Gemma 3n, wrapping the full capabilities of a feature extractor, image processor, and tokenizer
|
||||||
|
into a single processor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feature_extractor (`Gemma3nAudioFeatureExtractor`):
|
||||||
|
Feature extractor that converts raw audio waveforms into MEL spectrograms for the audio encoder. This
|
||||||
|
should return a `BatchFeature` with `input_features` and `input_features_mask` features.
|
||||||
|
image_processor (`SiglipImageProcessorFast`):
|
||||||
|
Image processor that prepares batches of images for the vision encoder. This should return a `BatchFeature`
|
||||||
|
with a `pixel_values` feature.
|
||||||
|
tokenizer (`GemmaTokenizerFast`):
|
||||||
|
The text tokenizer for the model.
|
||||||
|
chat_template (`str`, *optional*):
|
||||||
|
A Jinja template for generating text prompts from a set of messages.
|
||||||
|
audio_seq_length (`int`, *optional*, defaults to 188):
|
||||||
|
The number of audio soft tokens that will be added to the text prompt.
|
||||||
|
image_seq_length (`int`, *optional*, defaults to 256):
|
||||||
|
The number of image soft tokens that will be added to the text prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
attributes = ["feature_extractor", "image_processor", "tokenizer"]
|
||||||
|
feature_extractor_class = "AutoFeatureExtractor"
|
||||||
|
image_processor_class = "AutoImageProcessor"
|
||||||
|
tokenizer_class = "AutoTokenizer"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
feature_extractor,
|
||||||
|
image_processor,
|
||||||
|
tokenizer,
|
||||||
|
chat_template=None,
|
||||||
|
audio_seq_length: int = 188,
|
||||||
|
image_seq_length: int = 256,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.audio_seq_length = audio_seq_length
|
||||||
|
self.audio_token_id = tokenizer.audio_token_id
|
||||||
|
self.boa_token = tokenizer.boa_token
|
||||||
|
self.audio_token = tokenizer.audio_token
|
||||||
|
audio_tokens_expanded = "".join([tokenizer.audio_token] * audio_seq_length)
|
||||||
|
self.full_audio_sequence = f"\n\n{tokenizer.boa_token}{audio_tokens_expanded}{tokenizer.eoa_token}\n\n"
|
||||||
|
|
||||||
|
self.image_seq_length = image_seq_length
|
||||||
|
self.image_token_id = tokenizer.image_token_id
|
||||||
|
self.boi_token = tokenizer.boi_token
|
||||||
|
self.image_token = tokenizer.image_token
|
||||||
|
image_tokens_expanded = "".join([tokenizer.image_token] * image_seq_length)
|
||||||
|
self.full_image_sequence = f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
feature_extractor=feature_extractor,
|
||||||
|
image_processor=image_processor,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
chat_template=chat_template,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
images: ImageInput = None,
|
||||||
|
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
|
||||||
|
audio: Optional[Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]]] = None,
|
||||||
|
videos=None,
|
||||||
|
**kwargs: Unpack[Gemma3nProcessorKwargs],
|
||||||
|
) -> BatchFeature:
|
||||||
|
if text is None and images is None and audio is None:
|
||||||
|
raise ValueError("Provide at least one of `text`, `images`, or `audio`.")
|
||||||
|
|
||||||
|
output_kwargs = self._merge_kwargs(
|
||||||
|
Gemma3nProcessorKwargs,
|
||||||
|
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(text, str):
|
||||||
|
text = [text]
|
||||||
|
elif text is not None and not isinstance(text, list) and not isinstance(text[0], str):  # `text` may be None for audio-only or image-only calls
|
||||||
|
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||||
|
|
||||||
|
if audio is not None:
|
||||||
|
audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
text = [self.audio_token for _ in audio]
|
||||||
|
|
||||||
|
# Expand placeholder audio tokens to the full audio token sequence
|
||||||
|
text = [prompt.replace(self.audio_token, self.full_audio_sequence) for prompt in text]
|
||||||
|
else:
|
||||||
|
audio_inputs = {}
|
||||||
|
|
||||||
|
if images is not None:
|
||||||
|
batched_images = make_nested_list_of_images(images)
|
||||||
|
image_inputs = self.image_processor(batched_images, **output_kwargs["images_kwargs"])
|
||||||
|
|
||||||
|
# Create empty text to be replaced with placeholders
|
||||||
|
if not text:
|
||||||
|
text = [" ".join([self.image_token] * len(images)) for images in batched_images]
|
||||||
|
|
||||||
|
if len(batched_images) != len(text):
|
||||||
|
raise ValueError(
|
||||||
|
f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expand placeholder image tokens to the full image token sequence
|
||||||
|
text = [prompt.replace(self.image_token, self.full_image_sequence) for prompt in text]
|
||||||
|
else:
|
||||||
|
image_inputs = {}
|
||||||
|
|
||||||
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
|
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
|
||||||
|
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
|
||||||
|
|
||||||
|
# Add token type ids manually, as tokenizer can't do arbitrary position token types
|
||||||
|
array_ids = text_inputs["input_ids"]
|
||||||
|
token_type_ids = np.zeros_like(array_ids)
|
||||||
|
token_type_ids[array_ids == self.image_token_id] = 1
|
||||||
|
token_type_ids[array_ids == self.audio_token_id] = 3
|
||||||
|
text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
|
||||||
|
text_inputs["token_type_ids"] = token_type_ids.tolist()
|
||||||
|
return BatchFeature(data={**text_inputs, **image_inputs, **audio_inputs}, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
|
||||||
|
def batch_decode(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||||
|
refer to the docstring of this method for more information.
|
||||||
|
"""
|
||||||
|
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||||
|
|
||||||
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
|
||||||
|
def decode(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||||
|
the docstring of this method for more information.
|
||||||
|
"""
|
||||||
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_input_names(self):
|
||||||
|
tokenizer_input_names = self.tokenizer.model_input_names + ["token_type_ids"]
|
||||||
|
image_processor_input_names = self.image_processor.model_input_names
|
||||||
|
feature_extractor_input_names = self.feature_extractor.model_input_names
|
||||||
|
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + feature_extractor_input_names))
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["Gemma3nProcessor"]
|
||||||
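The placeholder expansion and manual `token_type_ids` assignment in `Gemma3nProcessor.__call__` above can be sketched without the tokenizer. This is an illustrative sketch under stated assumptions: the token strings (`<img>`, `<boi>`, `<eoi>`) and the integer ids are hypothetical stand-ins, not the real Gemma 3n vocabulary, and `expand_placeholder` and `make_token_type_ids` are hypothetical helper names.

```python
def expand_placeholder(prompt, token, begin_token, end_token, seq_length):
    """Replace each placeholder with begin + seq_length soft tokens + end.

    Mirrors how full_image_sequence / full_audio_sequence are built in the
    diff: the expanded run is wrapped in blank lines on both sides.
    """
    full_sequence = f"\n\n{begin_token}{token * seq_length}{end_token}\n\n"
    return prompt.replace(token, full_sequence)


def make_token_type_ids(input_ids, image_token_id, audio_token_id):
    """Tag image soft tokens with 1, audio soft tokens with 3, the rest 0."""
    type_ids = []
    for token_id in input_ids:
        if token_id == image_token_id:
            type_ids.append(1)
        elif token_id == audio_token_id:
            type_ids.append(3)
        else:
            type_ids.append(0)
    return type_ids


# Hypothetical tokens and ids, chosen only for illustration.
expanded = expand_placeholder("Describe <img>", "<img>", "<boi>", "<eoi>", seq_length=3)
type_ids = make_token_type_ids([5, 7, 7, 9, 2], image_token_id=7, audio_token_id=9)
```

Because `str.replace` does not rescan its own output, the soft tokens inside the expanded sequence are not expanded a second time.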
@@ -1642,6 +1642,7 @@ def set_model_tester_for_less_flaky_test(test_case):
|
|||||||
"AriaVisionText2TextModelTester",
|
"AriaVisionText2TextModelTester",
|
||||||
"GPTNeoModelTester",
|
"GPTNeoModelTester",
|
||||||
"DPTModelTester",
|
"DPTModelTester",
|
||||||
|
"Gemma3nTextModelTester", # cannot have a single layer combined with the cache sharing config attrs in the tester
|
||||||
]
|
]
|
||||||
if test_case.model_tester.__class__.__name__ in exceptional_classes:
|
if test_case.model_tester.__class__.__name__ in exceptional_classes:
|
||||||
target_num_hidden_layers = None
|
target_num_hidden_layers = None
|
||||||
|
|||||||
0
tests/models/gemma3n/__init__.py
Normal file
277
tests/models/gemma3n/test_feature_extraction_gemma3n.py
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
# Copyright 2025 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from transformers.models.gemma3n import Gemma3nAudioFeatureExtractor
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
check_json_file_has_correct_format,
|
||||||
|
require_torch,
|
||||||
|
)
|
||||||
|
from transformers.utils.import_utils import is_torch_available
|
||||||
|
|
||||||
|
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
pass
|
||||||
|
|
||||||
|
global_rng = random.Random()
|
||||||
|
|
||||||
|
MAX_LENGTH_FOR_TESTING = 512
|
||||||
|
|
||||||
|
|
||||||
|
def floats_list(shape, scale=1.0, rng=None):
|
||||||
|
"""Creates a nested list of random floats with the given 2D shape, each scaled by `scale`."""
|
||||||
|
if rng is None:
|
||||||
|
rng = global_rng
|
||||||
|
|
||||||
|
values = []
|
||||||
|
for _ in range(shape[0]):
|
||||||
|
values.append([])
|
||||||
|
for _ in range(shape[1]):
|
||||||
|
values[-1].append(rng.random() * scale)
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3nAudioFeatureExtractionTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=7,
|
||||||
|
min_seq_length=400,
|
||||||
|
max_seq_length=2000,
|
||||||
|
feature_size: int = 128,
|
||||||
|
sampling_rate: int = 16_000,
|
||||||
|
padding_value: float = 0.0,
|
||||||
|
return_attention_mask: bool = False,
|
||||||
|
# ignore hop_length / frame_length for now, as ms -> length conversion causes issues with serialization tests
|
||||||
|
# frame_length_ms: float = 32.0,
|
||||||
|
# hop_length: float = 10.0,
|
||||||
|
min_frequency: float = 125.0,
|
||||||
|
max_frequency: float = 7600.0,
|
||||||
|
preemphasis: float = 0.97,
|
||||||
|
preemphasis_htk_flavor: bool = True,
|
||||||
|
fft_overdrive: bool = True,
|
||||||
|
dither: float = 0.0,
|
||||||
|
input_scale_factor: float = 1.0,
|
||||||
|
mel_floor: float = 1e-5,
|
||||||
|
per_bin_mean: Optional[Sequence[float]] = None,
|
||||||
|
per_bin_stddev: Optional[Sequence[float]] = None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.min_seq_length = min_seq_length
|
||||||
|
self.max_seq_length = max_seq_length
|
||||||
|
self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
|
||||||
|
self.feature_size = feature_size
|
||||||
|
self.sampling_rate = sampling_rate
|
||||||
|
self.padding_value = padding_value
|
||||||
|
self.return_attention_mask = return_attention_mask
|
||||||
|
# ignore hop_length / frame_length for now, as ms -> length conversion causes issues with serialization tests
|
||||||
|
# self.frame_length_ms = frame_length_ms
|
||||||
|
# self.hop_length = hop_length
|
||||||
|
self.min_frequency = min_frequency
|
||||||
|
self.max_frequency = max_frequency
|
||||||
|
self.preemphasis = preemphasis
|
||||||
|
self.preemphasis_htk_flavor = preemphasis_htk_flavor
|
||||||
|
self.fft_overdrive = fft_overdrive
|
||||||
|
self.dither = dither
|
||||||
|
self.input_scale_factor = input_scale_factor
|
||||||
|
self.mel_floor = mel_floor
|
||||||
|
self.per_bin_mean = per_bin_mean
|
||||||
|
self.per_bin_stddev = per_bin_stddev
|
||||||
|
|
||||||
|
def prepare_feat_extract_dict(self):
|
||||||
|
return {
|
||||||
|
"feature_size": self.feature_size,
|
||||||
|
"sampling_rate": self.sampling_rate,
|
||||||
|
"padding_value": self.padding_value,
|
||||||
|
"return_attention_mask": self.return_attention_mask,
|
||||||
|
"min_frequency": self.min_frequency,
|
||||||
|
"max_frequency": self.max_frequency,
|
||||||
|
"preemphasis": self.preemphasis,
|
||||||
|
"preemphasis_htk_flavor": self.preemphasis_htk_flavor,
|
||||||
|
"fft_overdrive": self.fft_overdrive,
|
||||||
|
"dither": self.dither,
|
||||||
|
"input_scale_factor": self.input_scale_factor,
|
||||||
|
"mel_floor": self.mel_floor,
|
||||||
|
"per_bin_mean": self.per_bin_mean,
|
||||||
|
"per_bin_stddev": self.per_bin_stddev,
|
||||||
|
}
|
||||||
|
|
||||||
|
def prepare_inputs_for_common(self, equal_length=False, numpify=False):
|
||||||
|
def _flatten(list_of_lists):
|
||||||
|
return list(itertools.chain(*list_of_lists))
|
||||||
|
|
||||||
|
if equal_length:
|
||||||
|
speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
|
||||||
|
else:
|
||||||
|
# make sure that inputs increase in size
|
||||||
|
speech_inputs = [
|
||||||
|
floats_list((x, self.feature_size))
|
||||||
|
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
|
||||||
|
]
|
||||||
|
if numpify:
|
||||||
|
speech_inputs = [np.asarray(x) for x in speech_inputs]
|
||||||
|
return speech_inputs
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3nAudioFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
|
||||||
|
feature_extraction_class = Gemma3nAudioFeatureExtractor
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.feat_extract_tester = Gemma3nAudioFeatureExtractionTester(self)
|
||||||
|
|
||||||
|
def test_feat_extract_from_and_save_pretrained(self):
|
||||||
|
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
|
||||||
|
check_json_file_has_correct_format(saved_file)
|
||||||
|
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
dict_first = feat_extract_first.to_dict()
|
||||||
|
dict_second = feat_extract_second.to_dict()
|
||||||
|
mel_1 = feat_extract_first.mel_filters
|
||||||
|
mel_2 = feat_extract_second.mel_filters
|
||||||
|
self.assertTrue(np.allclose(mel_1, mel_2))
|
||||||
|
self.assertEqual(dict_first, dict_second)
|
||||||
|
|
||||||
|
def test_feat_extract_to_json_file(self):
|
||||||
|
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
json_file_path = os.path.join(tmpdirname, "feat_extract.json")
|
||||||
|
feat_extract_first.to_json_file(json_file_path)
|
||||||
|
feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
|
||||||
|
|
||||||
|
dict_first = feat_extract_first.to_dict()
|
||||||
|
dict_second = feat_extract_second.to_dict()
|
||||||
|
mel_1 = feat_extract_first.mel_filters
|
||||||
|
mel_2 = feat_extract_second.mel_filters
|
||||||
|
self.assertTrue(np.allclose(mel_1, mel_2))
|
||||||
|
self.assertEqual(dict_first, dict_second)
|
||||||
|
|
||||||
|
def test_feat_extract_from_pretrained_kwargs(self):
|
||||||
|
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
|
||||||
|
check_json_file_has_correct_format(saved_file)
|
||||||
|
feat_extract_second = self.feature_extraction_class.from_pretrained(
|
||||||
|
tmpdirname, feature_size=2 * self.feat_extract_dict["feature_size"]
|
||||||
|
)
|
||||||
|
|
||||||
|
mel_1 = feat_extract_first.mel_filters
|
||||||
|
mel_2 = feat_extract_second.mel_filters
|
||||||
|
self.assertTrue(2 * mel_1.shape[1] == mel_2.shape[1])
|
||||||
|
|
||||||
|
@parameterized.expand(
|
||||||
|
[
|
||||||
|
([floats_list((1, x))[0] for x in range(800, 1400, 200)],),
|
||||||
|
([floats_list((1, x))[0] for x in (800, 800, 800)],),
|
||||||
|
([floats_list((1, x))[0] for x in range(200, (MAX_LENGTH_FOR_TESTING + 500), 200)], True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def test_call(self, audio_inputs, test_truncation=False):
|
||||||
|
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||||
|
np_audio_inputs = [np.asarray(audio_input) for audio_input in audio_inputs]
|
||||||
|
|
||||||
|
input_features = feature_extractor(np_audio_inputs, padding="max_length", return_tensors="np").input_features
|
||||||
|
self.assertTrue(input_features.ndim == 3)
|
||||||
|
# input_features.shape should be (batch, num_frames, n_mels) ~= (batch, num_frames, feature_size)
|
||||||
|
# 480_000 is the max_length that inputs are padded to. we use that to calculate num_frames
|
||||||
|
expected_num_frames = (480_000 - feature_extractor.frame_length) // (feature_extractor.hop_length) + 1
|
||||||
|
self.assertTrue(
|
||||||
|
input_features.shape[-2] == expected_num_frames,
|
||||||
|
f"no match: {input_features.shape[-2]} vs {expected_num_frames}",
|
||||||
|
)
|
||||||
|
self.assertTrue(input_features.shape[-1] == feature_extractor.feature_size)
|
||||||
|
|
||||||
|
encoded_sequences_1 = feature_extractor(audio_inputs, return_tensors="np").input_features
|
||||||
|
encoded_sequences_2 = feature_extractor(np_audio_inputs, return_tensors="np").input_features
|
||||||
|
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||||
|
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||||
|
|
||||||
|
if test_truncation:
|
||||||
|
audio_inputs_truncated = [x[:MAX_LENGTH_FOR_TESTING] for x in audio_inputs]
|
||||||
|
np_audio_inputs_truncated = [np.asarray(audio_input) for audio_input in audio_inputs_truncated]
|
||||||
|
|
||||||
|
encoded_sequences_1 = feature_extractor(
|
||||||
|
audio_inputs_truncated, max_length=MAX_LENGTH_FOR_TESTING, return_tensors="np"
|
||||||
|
).input_features
|
||||||
|
encoded_sequences_2 = feature_extractor(
|
||||||
|
np_audio_inputs_truncated, max_length=MAX_LENGTH_FOR_TESTING, return_tensors="np"
|
||||||
|
).input_features
|
||||||
|
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||||
|
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||||
|
|
||||||
|
def test_dither(self):
|
||||||
|
np.random.seed(42) # seed the dithering randn()
|
||||||
|
|
||||||
|
# Tests that features with and without little dithering are similar, but not the same
|
||||||
|
dict_no_dither = self.feat_extract_tester.prepare_feat_extract_dict()
|
||||||
|
dict_no_dither["dither"] = 0.0
|
||||||
|
|
||||||
|
dict_dither = self.feat_extract_tester.prepare_feat_extract_dict()
|
||||||
|
dict_dither["dither"] = 0.00003 # approx. 1/32k
|
||||||
|
|
||||||
|
feature_extractor_no_dither = self.feature_extraction_class(**dict_no_dither)
|
||||||
|
feature_extractor_dither = self.feature_extraction_class(**dict_dither)
|
||||||
|
|
||||||
|
# create three inputs of length 800, 1000, and 1200
|
||||||
|
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||||
|
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
|
||||||
|
|
||||||
|
# compute features
|
||||||
|
input_features_no_dither = feature_extractor_no_dither(
|
||||||
|
np_speech_inputs, padding=True, return_tensors="np", sampling_rate=dict_no_dither["sampling_rate"]
|
||||||
|
).input_features
|
||||||
|
input_features_dither = feature_extractor_dither(
|
||||||
|
np_speech_inputs, padding=True, return_tensors="np", sampling_rate=dict_dither["sampling_rate"]
|
||||||
|
).input_features
|
||||||
|
|
||||||
|
# test there is a difference between features (there's added noise to input signal)
|
||||||
|
diff = input_features_dither - input_features_no_dither
|
||||||
|
|
||||||
|
# features are not identical
|
||||||
|
self.assertTrue(np.abs(diff).mean() > 1e-6)
|
||||||
|
# features are not too different
|
||||||
|
self.assertTrue(np.abs(diff).mean() <= 1e-4)
|
||||||
|
self.assertTrue(np.abs(diff).max() <= 5e-3)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_double_precision_pad(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||||
|
np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
|
||||||
|
py_speech_inputs = np_speech_inputs.tolist()
|
||||||
|
|
||||||
|
for inputs in [py_speech_inputs, np_speech_inputs]:
|
||||||
|
np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
|
||||||
|
self.assertTrue(np_processed.input_features.dtype == np.float32)
|
||||||
|
pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
|
||||||
|
self.assertTrue(pt_processed.input_features.dtype == torch.float32)
|
||||||
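As a rough illustration of what `test_dither` checks, the sketch below adds zero-mean Gaussian noise scaled by a `dither` value to a raw waveform and measures how much the signal changes. `add_dither` is a hypothetical helper, not the feature extractor's API, and it works on the waveform rather than on log-mel features, so the magnitudes are only indicative.

```python
import numpy as np


def add_dither(waveform, dither, rng):
    """Hypothetical helper: add zero-mean Gaussian noise scaled by `dither`."""
    if dither == 0.0:
        # No dithering requested: return an unchanged copy.
        return waveform.copy()
    return waveform + dither * rng.standard_normal(waveform.shape)


rng = np.random.default_rng(42)  # seeded so the noise is reproducible
waveform = rng.uniform(-1.0, 1.0, size=1000)

clean = add_dither(waveform, 0.0, rng)
noisy = add_dither(waveform, 0.00003, rng)

# The two signals differ, but only by a small amount.
diff = np.abs(noisy - clean)
print(diff.mean(), diff.max())
```

The test applies the same idea to extracted features: the mean absolute difference must be non-zero but small, and the maximum difference must stay bounded.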
886
tests/models/gemma3n/test_modeling_gemma3n.py
Normal file
@@ -0,0 +1,886 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Gemma3n model."""

import tempfile
import unittest

import numpy as np
import pytest
from datasets import load_dataset
from parameterized import parameterized

from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    Gemma3nAudioConfig,
    Gemma3nAudioFeatureExtractor,
    Gemma3nConfig,
    Gemma3nTextConfig,
    GenerationConfig,
    is_torch_available,
)
from transformers.testing_utils import (
    cleanup,
    require_flash_attn,
    require_read_token,
    require_torch,
    require_torch_gpu,
    slow,
    torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ..gemma.test_modeling_gemma import GemmaModelTester


if is_torch_available():
    import torch

    from transformers import (
        Gemma3nAudioEncoder,
        Gemma3nForCausalLM,
        Gemma3nForConditionalGeneration,
        Gemma3nModel,
        Gemma3nTextModel,
    )


class Gemma3nAudioModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=2,
|
||||||
|
num_channels=32, # feature_size / input_feat_size
|
||||||
|
sampling_rate=16_000,
|
||||||
|
raw_audio_length=8_000,
|
||||||
|
is_training=True,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.num_channels = num_channels
|
||||||
|
self.sampling_rate = sampling_rate
|
||||||
|
self.raw_audio_length = raw_audio_length
|
||||||
|
self.is_training = is_training
|
||||||
|
|
||||||
|
def get_feature_extractor_config(self):
|
||||||
|
return {
|
||||||
|
"feature_size": self.num_channels,
|
||||||
|
"sampling_rate": self.sampling_rate,
|
||||||
|
"padding_value": 0.0,
|
||||||
|
"return_attention_mask": True,
|
||||||
|
"frame_length_ms": 32.0,
|
||||||
|
"hop_length_ms": 10.0,
|
||||||
|
"dither": 0.0, # Important for determinism
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_audio_encoder_config(self):
|
||||||
|
return Gemma3nAudioConfig(
|
||||||
|
input_feat_size=self.num_channels,
|
||||||
|
hidden_size=32,
|
||||||
|
conf_num_attention_heads=4,
|
||||||
|
conf_num_hidden_layers=2,
|
||||||
|
sscp_conv_channel_size=(16, 8),
|
||||||
|
conf_conv_kernel_size=3,
|
||||||
|
conf_attention_chunk_size=4,
|
||||||
|
conf_attention_context_left=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
# Prepare inputs for the audio encoder
|
||||||
|
feature_extractor_config = self.get_feature_extractor_config()
|
||||||
|
audio_encoder_config = self.get_audio_encoder_config()
|
||||||
|
|
||||||
|
np.random.seed(0)
|
||||||
|
raw_speech_1 = np.sin(2 * np.pi * 440 * np.linspace(0, 1, self.raw_audio_length)).astype(np.float32)
|
||||||
|
raw_speech_2 = np.random.randn(self.raw_audio_length // 2).astype(np.float32)
|
||||||
|
raw_speech = [raw_speech_1, raw_speech_2]
|
||||||
|
|
||||||
|
feature_extractor = Gemma3nAudioFeatureExtractor(**feature_extractor_config)
|
||||||
|
audio_inputs = feature_extractor(raw_speech, return_tensors="pt")
|
||||||
|
|
||||||
|
input_features = audio_inputs["input_features"]
|
||||||
|
# The encoder expects a padding mask (True for padding), while the feature extractor
|
||||||
|
# returns an attention mask (True for valid tokens). We must invert it.
|
||||||
|
input_features_mask = ~audio_inputs["input_features_mask"].to(torch.bool)
|
||||||
|
|
||||||
|
inputs_dict = {
|
||||||
|
"audio_mel": input_features,
|
||||||
|
"audio_mel_mask": input_features_mask,
|
||||||
|
}
|
||||||
|
return audio_encoder_config, inputs_dict
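The mask inversion described in `prepare_config_and_inputs_for_common` can be sketched in NumPy as follows. The array values are made up for illustration; the only assumption carried over from the test is the convention that the feature extractor marks valid frames and the encoder expects padded frames to be marked.

```python
import numpy as np

# Attention-mask convention: 1 (True) marks a valid frame.
attention_mask = np.array(
    [
        [1, 1, 1, 1],  # full-length sample
        [1, 1, 0, 0],  # shorter sample, last two frames are padding
    ]
)

# Padding-mask convention: True marks a padded frame, so invert after casting to bool.
padding_mask = ~attention_mask.astype(bool)
print(padding_mask)
```

Casting to `bool` before applying `~` matters: on an integer array, `~` is a bitwise NOT and would turn 1 into -2 rather than False.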
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skip("Skipped for now!")
|
||||||
|
@require_torch
|
||||||
|
class Gemma3nAudioModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (Gemma3nAudioEncoder,) if is_torch_available() else ()
|
||||||
|
test_pruning = False
|
||||||
|
test_head_masking = False
|
||||||
|
test_missing_keys = False
|
||||||
|
is_generative = False
|
||||||
|
_is_stateful = True
|
||||||
|
main_input_name = "audio_mel"
|
||||||
|
test_initialization = False
|
||||||
|
test_can_init_all_missing_weights = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = Gemma3nAudioModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=Gemma3nAudioConfig, hidden_size=37)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
# The following values are golden outputs from a deterministic run of the components.
|
||||||
|
# They are used to ensure that changes to the code do not alter the numerical output.
|
||||||
|
# Generated with seeds np.random.seed(0) and torch.manual_seed(0).
|
||||||
|
self.expected_input_features_shape = (2, 48, 32)
|
||||||
|
self.expected_input_features_slice = np.array([-5.733152, -5.337127, -4.916284, -4.378989, -3.7622747])
|
||||||
|
self.expected_input_features_mask_shape = (2, 48)
|
||||||
|
self.expected_input_features_mask_slice = np.array([True, True, True, True, False])
|
||||||
|
|
||||||
|
self.expected_encoder_output_shape = (2, 3, 32)
|
||||||
|
self.expected_encoder_output_slice = torch.tensor([-0.4159, 0.6459, 0.6305, 2.2902, 0.9683])
|
||||||
|
self.expected_encoder_mask_shape = (2, 3)
|
||||||
|
self.expected_encoder_mask_slice = torch.tensor([False, False, True])
|
||||||
|
|
||||||
|
# Prepare a shared feature extractor and raw audio for the tests
|
||||||
|
self.feature_extractor = Gemma3nAudioFeatureExtractor(**self.model_tester.get_feature_extractor_config())
|
||||||
|
np.random.seed(0)
|
||||||
|
raw_speech_1 = np.sin(2 * np.pi * 440 * np.linspace(0, 1, self.model_tester.raw_audio_length)).astype(
|
||||||
|
np.float32
|
||||||
|
)
|
||||||
|
raw_speech_2 = np.random.randn(self.model_tester.raw_audio_length // 2).astype(np.float32)
|
||||||
|
self.raw_speech = [raw_speech_1, raw_speech_2]
|
||||||
|
|
||||||
|
@unittest.skip("Audio encoder does not support attention output")
|
||||||
|
def test_attention_outputs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Audio encoder does not support hidden state output")
|
||||||
|
def test_hidden_states_output(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Audio encoder returns a tuple, not a ModelOutput object, skipping equivalence test.")
|
||||||
|
def test_model_outputs_equivalence(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Audio encoder does not support retaining gradients on hidden states/attentions.")
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Audio encoder does not have a concept of token embeddings")
|
||||||
|
def test_model_get_set_embeddings(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Audio encoder does not have a concept of token embeddings")
|
||||||
|
def test_resize_tokens_embeddings(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("This model has a complex downsampling scheme that is hard to test with the generic batching test.")
|
||||||
|
def test_batching_equivalence(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_feature_extractor(self):
|
||||||
|
"""
|
||||||
|
Tests the feature extractor's output against pre-computed golden values.
|
||||||
|
This ensures the NumPy-based audio preprocessing is correct and consistent.
|
||||||
|
"""
|
||||||
|
audio_inputs = self.feature_extractor(
|
||||||
|
self.raw_speech, padding="longest", pad_to_multiple_of=128, return_tensors="np"
|
||||||
|
)
|
||||||
|
|
||||||
|
input_features = audio_inputs["input_features"]
|
||||||
|
self.assertEqual(input_features.shape, self.expected_input_features_shape)
|
||||||
|
np.testing.assert_allclose(input_features[0, 0, :5], self.expected_input_features_slice, rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
|
print(input_features[0, 0, :5])
|
||||||
|
|
||||||
|
input_features_mask = audio_inputs["input_features_mask"]
|
||||||
|
self.assertEqual(input_features_mask.shape, self.expected_input_features_mask_shape)
|
||||||
|
# The second audio sample is shorter than the first, so the tail of its mask is padding. Per the golden slice over indices 21:26, the mask first becomes False at index 25
|
||||||
|
np.testing.assert_array_equal(input_features_mask[1, 21:26], self.expected_input_features_mask_slice)
|
||||||
|
|
||||||
|
def test_audio_encoder(self):
|
||||||
|
"""
|
||||||
|
Tests the audio encoder's forward pass against pre-computed golden values.
|
||||||
|
This ensures the PyTorch-based audio encoding model is correct and consistent.
|
||||||
|
"""
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
model = Gemma3nAudioEncoder(config).to(torch_device).eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
encoder_output, encoder_mask = model(**inputs_dict)
|
||||||
|
|
||||||
|
print(encoder_output[0, 0, :5])
|
||||||
|
|
||||||
|
# Check output encodings
|
||||||
|
self.assertEqual(encoder_output.shape, self.expected_encoder_output_shape)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
encoder_output[0, 0, :5], self.expected_encoder_output_slice.to(torch_device), rtol=1e-4, atol=1e-4
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check output mask (True means padded)
|
||||||
|
# The second sample is shorter than the first, so its trailing frames are padding after the two downsampling stages (conv, then reduction).
# The golden mask recorded in setUp for that sample is [False, False, True]
|
||||||
|
self.assertEqual(encoder_mask.shape, self.expected_encoder_mask_shape)
|
||||||
|
torch.testing.assert_close(encoder_mask[1, :], self.expected_encoder_mask_slice.to(torch_device))
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3nTextModelTester(GemmaModelTester):
|
||||||
|
activation_sparsity_pattern = None
|
||||||
|
forced_config_args = ["activation_sparsity_pattern"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_input_mask=True,
|
||||||
|
use_token_type_ids=False,
|
||||||
|
use_labels=True,
|
||||||
|
vocab_size=99,
|
||||||
|
vocab_size_per_layer_input=99,
|
||||||
|
hidden_size=16,
|
||||||
|
num_hidden_layers=4, # override to correctly test sharing cache pattern
|
||||||
|
num_kv_shared_layers=2, # important to override
|
||||||
|
layer_types=[
|
||||||
|
"full_attention",
|
||||||
|
"sliding_attention",
|
||||||
|
"full_attention",
|
||||||
|
"sliding_attention",
|
||||||
|
], # similarly we want to test sharing on both types
|
||||||
|
num_attention_heads=2,
|
||||||
|
num_key_value_heads=2,
|
||||||
|
altup_num_inputs=2,
|
||||||
|
intermediate_size=21,
|
||||||
|
hidden_activation="gelu_pytorch_tanh",
|
||||||
|
max_position_embeddings=512,
|
||||||
|
type_vocab_size=16,
|
||||||
|
type_sequence_label_size=2,
|
||||||
|
initializer_range=0.02,
|
||||||
|
num_labels=3,
|
||||||
|
num_choices=4,
|
||||||
|
pad_token_id=0,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
is_decoder=False,
|
||||||
|
):
|
||||||
|
self._verify_model_attributes()
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_input_mask = use_input_mask
|
||||||
|
self.use_token_type_ids = use_token_type_ids
|
||||||
|
self.use_labels = use_labels
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.vocab_size_per_layer_input = vocab_size_per_layer_input
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_kv_shared_layers = num_kv_shared_layers
|
||||||
|
self.layer_types = layer_types
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.altup_num_inputs = altup_num_inputs
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.hidden_activation = hidden_activation
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.type_vocab_size = type_vocab_size
|
||||||
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.num_labels = num_labels
|
||||||
|
self.num_choices = num_choices
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
self.head_dim = self.hidden_size // self.num_attention_heads
|
||||||
|
self.is_decoder = is_decoder
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
config_class = Gemma3nTextConfig
|
||||||
|
model_class = Gemma3nTextModel
|
||||||
|
for_causal_lm_class = Gemma3nForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skip("Skipped for now!")
|
||||||
|
@require_torch
|
||||||
|
class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (Gemma3nTextModel, Gemma3nForCausalLM) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (Gemma3nForCausalLM,) if is_torch_available() else ()
|
||||||
|
test_head_masking = False
|
||||||
|
test_pruning = False
|
||||||
|
_is_stateful = True
|
||||||
|
model_split_percents = [0.5, 0.6]
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = Gemma3nTextModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(
|
||||||
|
self,
|
||||||
|
config_class=Gemma3nConfig,
|
||||||
|
hidden_size=37,
|
||||||
|
text_config={"activation_sparsity_pattern": None},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_hidden_states_for_generate(
|
||||||
|
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
|
||||||
|
):
|
||||||
|
"Gemma3n has special hidden states shape with 1 additional dim (which is then reduced with projections)"
|
||||||
|
|
||||||
|
self.assertIsInstance(hidden_states, tuple)
|
||||||
|
self.assertListEqual(
|
||||||
|
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
||||||
|
[True] * len(hidden_states),
|
||||||
|
)
|
||||||
|
self.assertEqual(len(hidden_states), (output_length - prompt_length))
|
||||||
|
|
||||||
|
# When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the
|
||||||
|
# new token(s)
|
||||||
|
# NOTE: `HybridCache` may have different lengths on different layers, if this test starts failing add more
|
||||||
|
# elaborate checks
|
||||||
|
for generated_length, iter_hidden_states in enumerate(hidden_states):
|
||||||
|
# regardless of using cache, the first forward pass will have the full prompt as input
|
||||||
|
if use_cache and generated_length > 0:
|
||||||
|
model_input_length = 1
|
||||||
|
else:
|
||||||
|
model_input_length = prompt_length + generated_length
|
||||||
|
expected_shape = (config.altup_num_inputs, batch_size, model_input_length, config.hidden_size)
|
||||||
|
# check hidden size
|
||||||
|
self.assertListEqual(
|
||||||
|
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
||||||
|
[expected_shape] * len(iter_hidden_states),
|
||||||
|
)
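The input-length rule that `_check_hidden_states_for_generate` relies on can be isolated in a few lines. `expected_input_lengths` is a hypothetical helper written for this illustration; it mirrors the branch in the loop above and is not part of the test suite.

```python
def expected_input_lengths(prompt_length, output_length, use_cache):
    """Hypothetical helper: model input length at each generation step."""
    lengths = []
    for generated_length in range(output_length - prompt_length):
        if use_cache and generated_length > 0:
            # With a cache, only the newly generated token is fed after the first step.
            lengths.append(1)
        else:
            # Without a cache (or on the first step), the whole sequence so far is fed.
            lengths.append(prompt_length + generated_length)
    return lengths


print(expected_input_lengths(3, 6, use_cache=True))
print(expected_input_lengths(3, 6, use_cache=False))
```

In the test, each of these lengths becomes the third dimension of the expected shape `(config.altup_num_inputs, batch_size, model_input_length, config.hidden_size)`.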
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma3nVision2TextModelTester:
|
||||||
|
text_config = {"activation_sparsity_pattern": None}
|
||||||
|
forced_config_args = ["text_config"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
mm_tokens_per_image=2,
|
||||||
|
image_token_index=1,
|
||||||
|
boi_token_index=2,
|
||||||
|
eoi_token_index=3,
|
||||||
|
seq_length=25,
|
||||||
|
is_training=True,
|
||||||
|
vision_config={
|
||||||
|
"use_labels": True,
|
||||||
|
"image_size": 20,
|
||||||
|
"patch_size": 5,
|
||||||
|
"num_channels": 3,
|
||||||
|
"is_training": True,
|
||||||
|
"hidden_size": 32,
|
||||||
|
"num_key_value_heads": 1,
|
||||||
|
"num_hidden_layers": 2,
|
||||||
|
"num_attention_heads": 4,
|
||||||
|
"intermediate_size": 37,
|
||||||
|
"dropout": 0.1,
|
||||||
|
"attention_dropout": 0.1,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
},
|
||||||
|
use_cache=False,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
# `image_token_index` is kept at a low id (1 by default) to pass the "resize_embeddings" test, do not modify
|
||||||
|
self.mm_tokens_per_image = mm_tokens_per_image
|
||||||
|
self.image_token_index = image_token_index
|
||||||
|
self.boi_token_index = boi_token_index
|
||||||
|
self.eoi_token_index = eoi_token_index
|
||||||
|
self.llm_tester = Gemma3nTextModelTester(self.parent)
|
||||||
|
self.text_config = self.llm_tester.get_config()
|
||||||
|
self.vision_config = vision_config
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.pad_token_id = self.text_config.pad_token_id
|
||||||
|
|
||||||
|
self.num_hidden_layers = self.text_config.num_hidden_layers
|
||||||
|
self.vocab_size = self.text_config.vocab_size
|
||||||
|
self.hidden_size = self.text_config.hidden_size
|
||||||
|
self.num_attention_heads = self.text_config.num_attention_heads
|
||||||
|
self.is_training = is_training
|
||||||
|
|
||||||
|
self.batch_size = 3
|
||||||
|
self.num_channels = vision_config["num_channels"]
|
||||||
|
self.image_size = vision_config["image_size"]
|
||||||
|
self.encoder_seq_length = seq_length
|
||||||
|
self.use_cache = use_cache
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return Gemma3nConfig(
|
||||||
|
text_config=self.text_config,
|
||||||
|
vision_config=self.vision_config,
|
||||||
|
image_token_index=self.image_token_index,
|
||||||
|
boi_token_index=self.boi_token_index,
|
||||||
|
eoi_token_index=self.eoi_token_index,
|
||||||
|
mm_tokens_per_image=self.mm_tokens_per_image,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
pixel_values = floats_tensor(
|
||||||
|
[
|
||||||
|
self.batch_size,
|
||||||
|
self.vision_config["num_channels"],
|
||||||
|
self.vision_config["image_size"],
|
||||||
|
self.vision_config["image_size"],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return config, pixel_values
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
config, pixel_values = config_and_inputs
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
|
||||||
|
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
|
||||||
|
|
||||||
|
# set the first token to be an image token, and ensure that no other tokens are image tokens
|
||||||
|
# do not change this unless you modified image size or patch size
|
||||||
|
input_ids[input_ids == config.image_token_index] = self.pad_token_id
|
||||||
|
input_ids[:, :1] = config.image_token_index
|
||||||
|
|
||||||
|
token_type_ids = torch.zeros_like(input_ids)
|
||||||
|
token_type_ids[input_ids == config.image_token_index] = 1
|
||||||
|
|
||||||
|
inputs_dict = {
|
||||||
|
"pixel_values": pixel_values,
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skip("Skipped for now!")
|
||||||
|
@require_torch
|
||||||
|
class Gemma3nVision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (Gemma3nModel, Gemma3nForConditionalGeneration) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (Gemma3nForConditionalGeneration,) if is_torch_available() else ()
|
||||||
|
test_head_masking = False
|
||||||
|
test_pruning = False
|
||||||
|
test_missing_keys = False
|
||||||
|
_is_stateful = True
|
||||||
|
model_split_percents = [0.5, 0.6]
|
||||||
|
|
||||||
|
# MP works but offload doesn't work when the SigLIP MultiheadAttention is offloaded
|
||||||
|
# TODO: One potential solution would be to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
|
||||||
|
# in the dispatch_model function
|
||||||
|
test_cpu_offload = False
|
||||||
|
test_disk_offload_safetensors = False
|
||||||
|
test_disk_offload_bin = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = Gemma3nVision2TextModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(
|
||||||
|
self,
|
||||||
|
config_class=Gemma3nConfig,
|
||||||
|
hidden_size=37,
|
||||||
|
text_config={"activation_sparsity_pattern": None},
|
||||||
|
)
|
||||||
|
|
||||||
|
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
|
||||||
|
def test_training_gradient_checkpointing(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
|
||||||
|
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
|
||||||
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
reason="HybridCache can't be gathered because it is not iterable. Adding a simple iter and dumping `distributed_iterator`"
" as in Dynamic Cache doesn't work. NOTE: @gante all cache objects would need better compatibility with multi gpu setting"
|
||||||
|
)
|
||||||
|
def test_multi_gpu_data_parallel_forward(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Failing because of unique cache (HybridCache)")
|
||||||
|
def test_model_outputs_equivalence(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@parameterized.expand([("random",), ("same",)])
|
||||||
|
@pytest.mark.generate
|
||||||
|
@unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding")
|
||||||
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding")
|
||||||
|
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
|
@unittest.skip("Gemma3n has HybridCache which is not compatible with assisted decoding")
|
||||||
|
def test_assisted_decoding_sample(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma3n has HybridCache which is not compatible with dola decoding")
|
||||||
|
def test_dola_decoding_sample(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma3n has HybridCache and doesn't support continue from past kv")
|
||||||
|
def test_generate_continue_from_past_key_values(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma3n has HybridCache and doesn't support low_memory generation")
|
||||||
|
def test_beam_search_low_memory(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation")
|
||||||
|
def test_contrastive_generate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation")
|
||||||
|
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma3n has HybridCache and doesn't support contrastive generation")
|
||||||
|
def test_contrastive_generate_low_memory(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma3n has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support it.")
|
||||||
|
def test_generate_with_static_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Gemma3n has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support it.")
|
||||||
|
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation"
|
||||||
|
)
|
||||||
|
def test_initialization(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
reason="Siglip has no FLEX attention, and we don't have a proper way to set/test attn in VLMs. TODO @raushan"
|
||||||
|
)
|
||||||
|
def test_flex_attention_with_grads(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_automodelforcausallm(self):
|
||||||
|
"""
|
||||||
|
Regression test for #36741 -- make sure `AutoModelForCausalLM` works with a Gemma3n config, i.e. that
|
||||||
|
`AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model
|
||||||
|
"""
|
||||||
|
config = self.model_tester.get_config()
|
||||||
|
model = Gemma3nForConditionalGeneration(config)
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir)
|
||||||
|
for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
|
||||||
|
self.assertIsInstance(for_causal_lm, Gemma3nForCausalLM)
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skip("Skipped for now!")
|
||||||
|
@slow
|
||||||
|
@require_torch_gpu
|
||||||
|
@require_read_token
|
||||||
|
class Gemma3nIntegrationTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.processor = AutoProcessor.from_pretrained("Google/gemma-3n-E4B-it", padding_side="left")
|
||||||
|
|
||||||
|
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
|
||||||
|
self.messages = [
|
||||||
|
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "url": url},
|
||||||
|
{"type": "text", "text": "What is shown in this image?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
audio_ds = load_dataset(
|
||||||
|
"etechgrid/28.5k_wavfiles_dataset", "default", data_files="wav_dataset/103-1240-0000.wav"
|
||||||
|
)
|
||||||
|
self.audio_file_path = audio_ds["train"][0]["audio"]["path"]
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
cleanup(torch_device, gc_collect=True)
|
||||||
|
|
||||||
|
def test_model_4b_bf16(self):
|
||||||
|
model_id = "Google/gemma-3n-E4B-it"
|
||||||
|
|
||||||
|
model = Gemma3nForConditionalGeneration.from_pretrained(
|
||||||
|
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
inputs = self.processor.apply_chat_template(
|
||||||
|
self.messages,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
add_generation_prompt=True,
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
|
||||||
|
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
||||||
|
|
||||||
|
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like'] # fmt: skip
|
||||||
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
|
|
||||||
|
def test_model_with_audio(self):
|
||||||
|
"""
|
||||||
|
Tests the full model pipeline with batched audio inputs provided as file paths.
|
||||||
|
This ensures the processor correctly loads and processes audio files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_id = "Google/gemma-3n-E4B-it"
|
||||||
|
|
||||||
|
model = Gemma3nForConditionalGeneration.from_pretrained(
|
||||||
|
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Transcribe the following speech segment in English:"},
|
||||||
|
{"type": "audio", "audio": str(self.audio_file_path)},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs = self.processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(torch_device, dtype=model.dtype)
|
||||||
|
|
||||||
|
input_len = inputs["input_ids"].shape[-1]
|
||||||
|
|
||||||
|
output = model.generate(**inputs, max_new_tokens=16, do_sample=False)
|
||||||
|
output = output[:, input_len:]
|
||||||
|
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
|
||||||
|
|
||||||
|
EXPECTED_TEXTS = ["Chapter 1. Mrs. Rachel Lind is surprised.\n\nMrs. Rachel Lind"]
|
||||||
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
|
|
||||||
|
    def test_model_4b_batch(self):
        model_id = "Google/gemma-3n-E4B-it"

        model = Gemma3nForConditionalGeneration.from_pretrained(
            model_id, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
        ).to(torch_device)

        messages_2 = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
                    },
                    {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
                    {"type": "text", "text": "Are these images identical?"},
                ],
            },
        ]

        inputs = self.processor.apply_chat_template(
            [self.messages, messages_2],
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding=True,
            add_generation_prompt=True,
        ).to(torch_device)

        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)

        EXPECTED_TEXTS = [
            'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
            "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow"
        ]  # fmt: skip
        self.assertEqual(output_text, EXPECTED_TEXTS)

    def test_model_4b_crops(self):
        model_id = "Google/gemma-3n-E4B-it"

        model = Gemma3nForConditionalGeneration.from_pretrained(
            model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
        ).to(torch_device)

        crop_config = {
            "images_kwargs": {
                "do_pan_and_scan": True,
                "pan_and_scan_max_num_crops": 448,
                "pan_and_scan_min_crop_size": 32,
                "pan_and_scan_min_ratio_to_activate": 0.3,
            }
        }

        inputs = self.processor.apply_chat_template(
            self.messages,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            add_generation_prompt=True,
            **crop_config,
        ).to(torch_device)

        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)

        EXPECTED_NUM_IMAGES = 3  # one for the original image and two for its crops
        EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a beach with a turquoise ocean and blue sky in the background.']  # fmt: skip
        self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES)
        self.assertEqual(output_text, EXPECTED_TEXTS)

    def test_model_4b_multiimage(self):
        model_id = "Google/gemma-3n-E4B-it"

        model = Gemma3nForConditionalGeneration.from_pretrained(
            model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
        ).to(torch_device)

        messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
                    {"type": "text", "text": "What do you see here?"},
                ],
            },
        ]

        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding=True,
            add_generation_prompt=True,
        ).to(torch_device)

        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)

        EXPECTED_TEXTS = ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"]  # fmt: skip
        self.assertEqual(output_text, EXPECTED_TEXTS)

    def test_model_1b_text_only(self):
        model_id = "google/gemma-3-1b-it"

        model = Gemma3nForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
            torch_device
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
        inputs = tokenizer("Write a poem about Machine Learning.", return_tensors="pt").to(torch_device)

        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)

        EXPECTED_TEXTS = ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning']  # fmt: skip
        self.assertEqual(output_text, EXPECTED_TEXTS)

    # TODO: raushan FA2 generates gibberish for no reason, check later
    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    def test_model_4b_flash_attn(self):
        model_id = "Google/gemma-3n-E4B-it"

        model = Gemma3nForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        ).to(torch_device)

        inputs = self.processor.apply_chat_template(
            self.messages,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(torch_device)

        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)

        EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. It looks like a very sunny and']  # fmt: skip
        self.assertEqual(output_text, EXPECTED_TEXTS)

    @parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)])
    def test_generation_beyond_sliding_window(self, attn_implementation: str):
        """Test that we can correctly generate beyond the sliding window. This is non-trivial as
        we need to correctly slice the attention mask in all cases (because we use a HybridCache).
        Outputs for every attention function should be coherent and identical.
        """
        model_id = "google/gemma-3-1b-it"

        input_text = [
            "This is a nice place. " * 800 + "I really enjoy the scenery,",  # This is larger than 4096 tokens
            "A list of colors: red, blue",  # This will almost all be padding tokens
        ]
        # `padding_side` (not `padding`) is the tokenizer argument that selects left padding
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
        inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)

        model = AutoModelForCausalLM.from_pretrained(
            model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
        ).to(torch_device)

        # Make sure prefill is larger than sliding window
        input_size = inputs.input_ids.shape[-1]
        self.assertTrue(input_size > model.config.sliding_window)

        out = model.generate(**inputs, max_new_tokens=20)[:, input_size:]
        output_text = tokenizer.batch_decode(out)

        EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"]  # fmt: skip
        self.assertEqual(output_text, EXPECTED_COMPLETIONS)

    def test_generation_beyond_sliding_window_with_generation_config(self):
        """
        Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
        ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
        """
        model_id = "google/gemma-3-1b-it"
        attn_implementation = "sdpa"

        input_text = [
            "This is a nice place. " * 800 + "I really enjoy the scenery,",  # This is larger than 4096 tokens
            "A list of colors: red, blue",  # This will almost all be padding tokens
        ]
        # `padding_side` (not `padding`) is the tokenizer argument that selects left padding
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
        inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)

        model = AutoModelForCausalLM.from_pretrained(
            model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
        ).to(torch_device)

        # Make sure prefill is larger than sliding window
        input_size = inputs.input_ids.shape[-1]
        self.assertTrue(input_size > model.config.sliding_window)

        generation_config = GenerationConfig(max_new_tokens=20)

        out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
        output_text = tokenizer.batch_decode(out)

        EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"]  # fmt: skip
        self.assertEqual(output_text, EXPECTED_COMPLETIONS)

185
tests/models/gemma3n/test_processing_gemma3n.py
Normal file
@@ -0,0 +1,185 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
import tempfile
import unittest

import numpy as np
from parameterized import parameterized

from transformers import GemmaTokenizerFast, SiglipImageProcessorFast, is_speech_available
from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio, require_vision

from .test_feature_extraction_gemma3n import floats_list


if is_speech_available():
    from transformers.models.gemma3n import Gemma3nAudioFeatureExtractor, Gemma3nProcessor


@require_torch
@require_torchaudio
@require_vision
@require_sentencepiece
class Gemma3nProcessorTest(unittest.TestCase):
    def setUp(self):
        # TODO: update to google?
        self.model_id = "Google/gemma-3n-E4B-it"
        self.tmpdirname = tempfile.mkdtemp(suffix="gemma3n")
        self.maxDiff = None

    def get_tokenizer(self, **kwargs):
        return GemmaTokenizerFast.from_pretrained(self.model_id, **kwargs)

    def get_feature_extractor(self, **kwargs):
        return Gemma3nAudioFeatureExtractor.from_pretrained(self.model_id, **kwargs)

    def get_image_processor(self, **kwargs):
        return SiglipImageProcessorFast.from_pretrained(self.model_id, **kwargs)

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def test_save_load_pretrained_default(self):
        # NOTE: feature_extractor and image_processor both use the same filename, preprocessor_config.json, when saved to
        # disk, but the files are overwritten by processor.save_pretrained(). This test does not attempt to address
        # this potential issue, and as such, does not guarantee content accuracy.

        tokenizer = self.get_tokenizer()
        feature_extractor = self.get_feature_extractor()
        image_processor = self.get_image_processor()

        processor = Gemma3nProcessor(
            tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
        )

        processor.save_pretrained(self.tmpdirname)
        processor = Gemma3nProcessor.from_pretrained(self.tmpdirname)

        self.assertIsInstance(processor.tokenizer, GemmaTokenizerFast)
        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())

        self.assertIsInstance(processor.feature_extractor, Gemma3nAudioFeatureExtractor)
        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())

    def test_save_load_pretrained_additional_features(self):
        tokenizer = self.get_tokenizer()
        feature_extractor = self.get_feature_extractor()
        image_processor = self.get_image_processor()

        processor = Gemma3nProcessor(
            tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
        )
        processor.save_pretrained(self.tmpdirname)

        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS-BOS)", eos_token="(EOS-EOS)")
        feature_extractor_add_kwargs = self.get_feature_extractor(dither=5.0, padding_value=1.0)

        processor = Gemma3nProcessor.from_pretrained(
            self.tmpdirname, bos_token="(BOS-BOS)", eos_token="(EOS-EOS)", dither=5.0, padding_value=1.0
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, GemmaTokenizerFast)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, Gemma3nAudioFeatureExtractor)

    @parameterized.expand([256, 512, 768, 1024])
    def test_image_processor(self, image_size: int):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()
        image_processor = self.get_image_processor()
        processor = Gemma3nProcessor(
            tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
        )

        raw_image = np.random.randint(0, 256, size=(image_size, image_size, 3), dtype=np.uint8)
        input_image_processor = image_processor(raw_image, return_tensors="pt")
        input_processor = processor(text="Describe:", images=raw_image, return_tensors="pt")

        for key in input_image_processor.keys():
            self.assertAlmostEqual(input_image_processor[key].sum(), input_processor[key].sum(), delta=1e-2)
            if "pixel_values" in key:
                # NOTE: all images should be re-scaled to 768x768
                self.assertEqual(input_image_processor[key].shape, (1, 3, 768, 768))
                self.assertEqual(input_processor[key].shape, (1, 3, 768, 768))

    def test_audio_feature_extractor(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()
        image_processor = self.get_image_processor()
        processor = Gemma3nProcessor(
            tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
        )

        raw_speech = floats_list((3, 1000))
        input_feat_extract = feature_extractor(raw_speech, return_tensors="pt")
        input_processor = processor(text="Transcribe:", audio=raw_speech, return_tensors="pt")

        for key in input_feat_extract.keys():
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()
        image_processor = self.get_image_processor()
        processor = Gemma3nProcessor(
            tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
        )

        input_str = "This is a test string"

        encoded_processor = processor(text=input_str)

        encoded_tok = tokenizer(input_str)

        for key in encoded_tok.keys():
            self.assertListEqual(encoded_tok[key], encoded_processor[key][0])

    def test_tokenizer_decode(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()
        image_processor = self.get_image_processor()
        processor = Gemma3nProcessor(
            tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
        )

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        decoded_processor = processor.batch_decode(predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()
        image_processor = self.get_image_processor()
        processor = Gemma3nProcessor(
            tokenizer=tokenizer, feature_extractor=feature_extractor, image_processor=image_processor
        )

        for key in feature_extractor.model_input_names:
            self.assertIn(
                key,
                processor.model_input_names,
            )

        for key in image_processor.model_input_names:
            self.assertIn(
                key,
                processor.model_input_names,
            )
@@ -277,6 +277,7 @@ SPECIAL_CASES_TO_ALLOW = {
     ],
     "Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
     "SmolLM3Config": ["no_rope_layer_interval"],
+    "Gemma3nVisionConfig": ["architecture", "do_pooling", "model_args"],  # this is for use in `timm`
 }
 
 
@@ -79,6 +79,7 @@ ALWAYS_OVERRIDE = ["labels"]
 # docstrings instead. If formatting should be ignored for the docstring, you can put a comment # no-format on the
 # line before the docstring.
 OBJECTS_TO_IGNORE = [
+    "Gemma3nVisionConfig",
     "Llama4Processor",
     # Deprecated
     "InputExample",