From 3ef889690649c082849c667be17b757c32955229 Mon Sep 17 00:00:00 2001 From: Biao Zhang <17406686+bzhangGo@users.noreply.github.com> Date: Wed, 25 Jun 2025 05:05:10 -0400 Subject: [PATCH] Encoder-Decoder Gemma (#38332) * Initial submit * Fix bugs: 1. add __init__ file 2. tied word embedding 3. support flash/flex attention 4. model saving and loading * Code refactor: * Rename encdecgemma to t5gemma. * Split attention into self- and cross-attention * Split stack into encoder and decoder * Add test cases * Add auto configuration * Update configurations. * Fix bugs related to copy and attribute checks * Fix type union * Fix merge errors * run ruff format * Run make style and update tests. * Add t5gemma model doc. * ruff and style formatting. * Add missed module config. * Add dummy checkpoint link to pass tests (need updated when real checkpoints are uplioaded.). * Update model doc. * Minor updates following Arthur's comments: * replace docstrings with auto_docstrings * remove checkpoint layers * remove deprecate_kwargs * fix rebase errors * Fix docstring issues. * fix t5gemma doc issue. 
* run ruff format * Updates: * split encoder-only model out * make t5gemmamodel encoder-decoder only * update token and sequence classification * update tests --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/t5gemma.md | 107 ++ src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 7 + .../models/auto/tokenization_auto.py | 7 + src/transformers/models/t5gemma/__init__.py | 27 + .../models/t5gemma/configuration_t5gemma.py | 333 ++++ .../models/t5gemma/modeling_t5gemma.py | 1506 +++++++++++++++ .../models/t5gemma/modular_t5gemma.py | 1455 ++++++++++++++ tests/models/t5gemma/__init__.py | 0 tests/models/t5gemma/test_modeling_t5gemma.py | 1701 +++++++++++++++++ 12 files changed, 5148 insertions(+) create mode 100644 docs/source/en/model_doc/t5gemma.md create mode 100644 src/transformers/models/t5gemma/__init__.py create mode 100644 src/transformers/models/t5gemma/configuration_t5gemma.py create mode 100644 src/transformers/models/t5gemma/modeling_t5gemma.py create mode 100644 src/transformers/models/t5gemma/modular_t5gemma.py create mode 100644 tests/models/t5gemma/__init__.py create mode 100644 tests/models/t5gemma/test_modeling_t5gemma.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 1e6b01759f..bf089a0f6a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -655,6 +655,8 @@ title: SwitchTransformers - local: model_doc/t5 title: T5 + - local: model_doc/t5gemma + title: T5Gemma - local: model_doc/t5v1.1 title: T5v1.1 - local: model_doc/tapex diff --git a/docs/source/en/model_doc/t5gemma.md b/docs/source/en/model_doc/t5gemma.md new file mode 100644 index 0000000000..d8615a9add --- /dev/null +++ b/docs/source/en/model_doc/t5gemma.md @@ -0,0 +1,107 @@ + + + + +# T5Gemma + +T5Gemma (aka encoder-decoder Gemma) was proposed in a [research paper](https://arxiv.org/abs/2504.06225) by Google. 
It is a family of encoder-decoder large language models, developed by adapting pretrained decoder-only models into encoder-decoder. T5Gemma includes pretrained and instruction-tuned variants. The architecture is based on transformer encoder-decoder design following T5, with improvements from Gemma 2: GQA, RoPE, GeGLU activation, RMSNorm, and interleaved local/global attention. + +T5Gemma has two groups of model sizes: 1) [Gemma 2](https://ai.google.dev/gemma/docs/core/model_card_2) sizes (2B-2B, 9B-2B, and 9B-9B), which are based on the official Gemma 2 models (2B and 9B); and 2) [T5](https://arxiv.org/abs/1910.10683) sizes (Small, Base, Large, and XL), which are pretrained under the Gemma 2 framework following T5 configuration. In addition, we also provide a model at ML size (medium large, ~2B in total), which is in-between T5 Large and T5 XL. + +The pretrained variants are trained with two objectives: prefix language modeling with knowledge distillation (PrefixLM) and UL2, separately. We release both variants for each model size. The instruction-tuned variants were post-trained with supervised fine-tuning and reinforcement learning. + +The example below demonstrates how to chat with the model with [`Pipeline`] or the [`AutoModel`] class, and from the command line. 
+ + + + + +```python +import torch +from transformers import pipeline + +pipe = pipeline( + task="text2text-generation", + model="google/t5gemma-placeholder", + torch_dtype=torch.bfloat16, + device="cuda", +) + +pipe("Question: Why is the sky blue?\nAnswer:", max_new_tokens=50) +``` + + + + +```python +import torch +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + +tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-placeholder") +model = AutoModelForSeq2SeqLM.from_pretrained( + "google/t5gemma-placeholder", + torch_dtype=torch.bfloat16, + device_map="auto" +) + +input_text = "Question: Why is the sky blue?\nAnswer:" +input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") + +outputs = model.generate(**input_ids, max_new_tokens=32) +print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + +``` + + + + +``` +echo -e "Question: Why is the sky blue? Answer:" | transformers run --task text2text-generation --model google/t5gemma-placeholder --device 0 +``` + +## T5GemmaConfig + +[[autodoc]] T5GemmaConfig + +## T5GemmaModuleConfig + +[[autodoc]] T5GemmaModuleConfig + +## T5GemmaModel + +[[autodoc]] T5GemmaModel + - forward + +## T5GemmaEncoderModel + +[[autodoc]] T5GemmaEncoderModel + - forward + +## T5GemmaForConditionalGeneration + +[[autodoc]] T5GemmaForConditionalGeneration + - forward + +## T5GemmaForSequenceClassification + +[[autodoc]] T5GemmaForSequenceClassification + - forward + +## T5GemmaForTokenClassification + +[[autodoc]] T5GemmaForTokenClassification + - forward diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 8d36068353..c53fdfc7a3 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -294,6 +294,7 @@ if TYPE_CHECKING: from .swinv2 import * from .switch_transformers import * from .t5 import * + from .t5gemma import * from .table_transformer import * from .tapas import * from .textnet import * diff --git 
a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 3812712bed..3758e237e2 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -333,6 +333,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str]( ("swinv2", "Swinv2Config"), ("switch_transformers", "SwitchTransformersConfig"), ("t5", "T5Config"), + ("t5gemma", "T5GemmaConfig"), ("table-transformer", "TableTransformerConfig"), ("tapas", "TapasConfig"), ("textnet", "TextNetConfig"), @@ -721,6 +722,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str]( ("swinv2", "Swin Transformer V2"), ("switch_transformers", "SwitchTransformers"), ("t5", "T5"), + ("t5gemma", "T5Gemma"), ("t5v1.1", "T5v1.1"), ("table-transformer", "Table Transformer"), ("tapas", "TAPAS"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index f2ccf21f58..935eb8fe8a 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -310,6 +310,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("swinv2", "Swinv2Model"), ("switch_transformers", "SwitchTransformersModel"), ("t5", "T5Model"), + ("t5gemma", "T5GemmaModel"), ("table-transformer", "TableTransformerModel"), ("tapas", "TapasModel"), ("textnet", "TextNetModel"), @@ -430,6 +431,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), + ("t5gemma", "T5GemmaForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("tvlt", "TvltForPreTraining"), @@ -524,6 +526,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), + ("t5gemma", 
"T5GemmaForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("wav2vec2", "Wav2Vec2ForMaskedLM"), @@ -1044,6 +1047,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), ("t5", "T5ForConditionalGeneration"), + ("t5gemma", "T5GemmaForConditionalGeneration"), ("umt5", "UMT5ForConditionalGeneration"), ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"), ] @@ -1156,6 +1160,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("stablelm", "StableLmForSequenceClassification"), ("starcoder2", "Starcoder2ForSequenceClassification"), ("t5", "T5ForSequenceClassification"), + ("t5gemma", "T5GemmaForSequenceClassification"), ("tapas", "TapasForSequenceClassification"), ("transfo-xl", "TransfoXLForSequenceClassification"), ("umt5", "UMT5ForSequenceClassification"), @@ -1349,6 +1354,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("stablelm", "StableLmForTokenClassification"), ("starcoder2", "Starcoder2ForTokenClassification"), ("t5", "T5ForTokenClassification"), + ("t5gemma", "T5GemmaForTokenClassification"), ("umt5", "UMT5ForTokenClassification"), ("xlm", "XLMForTokenClassification"), ("xlm-roberta", "XLMRobertaForTokenClassification"), @@ -1582,6 +1588,7 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict( ("roformer", "RoFormerModel"), ("squeezebert", "SqueezeBertModel"), ("t5", "T5EncoderModel"), + ("t5gemma", "T5GemmaEncoderModel"), ("umt5", "UMT5EncoderModel"), ("xlm", "XLMModel"), ("xlm-roberta", "XLMRobertaModel"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 4112d111e1..50a1a2732c 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -582,6 +582,13 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, 
tuple[Optional[str], Optional[str]]]( "T5TokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "t5gemma", + ( + "GemmaTokenizer" if is_sentencepiece_available() else None, + "GemmaTokenizerFast" if is_tokenizers_available() else None, + ), + ), ("tapas", ("TapasTokenizer", None)), ("tapex", ("TapexTokenizer", None)), ("transfo-xl", ("TransfoXLTokenizer", None)), diff --git a/src/transformers/models/t5gemma/__init__.py b/src/transformers/models/t5gemma/__init__.py new file mode 100644 index 0000000000..aa8099e267 --- /dev/null +++ b/src/transformers/models/t5gemma/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_t5gemma import * + from .modeling_t5gemma import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/t5gemma/configuration_t5gemma.py b/src/transformers/models/t5gemma/configuration_t5gemma.py new file mode 100644 index 0000000000..b3aa23d0be --- /dev/null +++ b/src/transformers/models/t5gemma/configuration_t5gemma.py @@ -0,0 +1,333 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/t5gemma/modular_t5gemma.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_t5gemma.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Optional, Union + +from ...configuration_utils import PretrainedConfig, layer_type_validation + + +class T5GemmaModuleConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5GemmaModuleModel`]. It is used to instantiate a T5GemmaModule + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the T5GemmaModule-7B. + e.g. [google/t5_gemma_module-7b](https://huggingface.co/google/t5_gemma_module-7b) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the T5GemmaModule model. Defines the number of different tokens that can be represented by the + `input_ids` passed when calling [`T5GemmaModuleModel`] + hidden_size (`int`, *optional*, defaults to 2304): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 9216): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 26): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 4): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. 
For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"` + if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ query_pre_attn_scalar (`float`, *optional*, defaults to 256): + scaling factor used on the attention scores + sliding_window (`int`, *optional*, defaults to 4096): + in T5GemmaModule, every other layer uses sliding window attention. This is the size of the sliding window. + layer_types (`list`, *optional*): + Attention pattern for each layer. + final_logit_softcapping (`float`, *optional*, defaults to 30.0): + scaling factor when applying tanh softcapping on the logits. + attn_logit_softcapping (`float`, *optional*, defaults to 50.0): + scaling factor when applying tanh softcapping on the attention scores. + + ```python + >>> from transformers import T5GemmaModuleModel, T5GemmaModuleConfig + >>> # Initializing a T5GemmaModule t5_gemma_module-7b style configuration + >>> configuration = T5GemmaModuleConfig() + >>> # Initializing a model from the t5_gemma_module-7b style configuration + >>> model = T5GemmaModuleModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + Module config (encoder or decoder): the same as Gemma2Config.""" + + model_type = "t5_gemma_module" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=256000, + hidden_size=2304, + intermediate_size=9216, + num_hidden_layers=26, + num_attention_heads=8, + num_key_value_heads=4, + head_dim=256, + hidden_activation="gelu_pytorch_tanh", + max_position_embeddings=8192, + initializer_range=0.02, + 
rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + query_pre_attn_scalar=256, + sliding_window=4096, + layer_types=None, + final_logit_softcapping=30.0, + attn_logit_softcapping=50.0, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.hidden_activation = hidden_activation + self.query_pre_attn_scalar = query_pre_attn_scalar + self.sliding_window = sliding_window + self.final_logit_softcapping = final_logit_softcapping + self.attn_logit_softcapping = attn_logit_softcapping + self.layer_types = layer_types + + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + + +class T5GemmaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`T5GemmaModel`]. It is used to instantiate a T5Gemma + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to a hypothetical balanced Gemma2 encoder-decoder model. + e.g. 
[google/t5gemma-placeholder](https://huggingface.co/google/t5gemma-placeholder) + ```python + >>> from transformers import T5GemmaConfig, T5GemmaModel + >>> t5gemma_config = T5GemmaConfig.from_pretrained("google/t5gemma-placeholder") + >>> model = T5GemmaModel(t5gemma_config) + ``` + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + encoder (`Union[T5GemmaModuleConfig, dict]`, *optional*): + Configuration for the encoder. + decoder (`Union[T5GemmaModuleConfig, dict]`, *optional*): + Configuration for the decoder. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether the model is used as an encoder/decoder or not. + dropout_rate (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers (following T5). + classifier_dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier (following T5). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for attention. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie input and output embeddings. + kwargs (additional keyword arguments, *optional*): + Will be passed to the PretrainedConfig base class. 
+ """ + + model_type = "t5gemma" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + # encoder + "encoder.layers.*.self_attn.q_proj": "colwise", + "encoder.layers.*.self_attn.k_proj": "colwise", + "encoder.layers.*.self_attn.v_proj": "colwise", + "encoder.layers.*.self_attn.o_proj": "rowwise", + "encoder.layers.*.mlp.gate_proj": "colwise", + "encoder.layers.*.mlp.up_proj": "colwise", + "encoder.layers.*.mlp.down_proj": "rowwise", + # decoder + "decoder.layers.*.self_attn.q_proj": "colwise", + "decoder.layers.*.self_attn.k_proj": "colwise", + "decoder.layers.*.self_attn.v_proj": "colwise", + "decoder.layers.*.self_attn.o_proj": "rowwise", + "decoder.layers.*.cross_attn.q_proj": "colwise", + "decoder.layers.*.cross_attn.k_proj": "colwise", + "decoder.layers.*.cross_attn.v_proj": "colwise", + "decoder.layers.*.cross_attn.o_proj": "rowwise", + "decoder.layers.*.mlp.gate_proj": "colwise", + "decoder.layers.*.mlp.up_proj": "colwise", + "decoder.layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + # encoder + "encoder.embed_tokens": (["input_ids"], ["inputs_embeds"]), + "encoder.layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "encoder.norm": (["hidden_states"], ["hidden_states"]), + # decoder + "decoder.embed_tokens": (["input_ids"], ["inputs_embeds"]), + "decoder.layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "decoder.norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + encoder: Optional[Union[T5GemmaModuleConfig, dict[Any, Any]]] = None, + decoder: Optional[Union[T5GemmaModuleConfig, dict[Any, Any]]] = None, + is_encoder_decoder: bool = True, + dropout_rate: float = 0.0, + classifier_dropout_rate: float = 0.0, + attention_dropout: float = 0.0, + tie_word_embeddings: bool = True, + **kwargs, + ): + # Encoder. 
+ if isinstance(encoder, dict): + # From preset configuration + encoder = T5GemmaModuleConfig(**encoder) + elif encoder is None: + # From scratch + encoder = T5GemmaModuleConfig() + else: + assert isinstance(encoder, T5GemmaModuleConfig), f"{type(encoder)} is not supported." + + # Decoder. + if isinstance(decoder, dict): + # From preset configuration + decoder = T5GemmaModuleConfig(**decoder) + elif decoder is None: + # From scratch + decoder = encoder + else: + assert isinstance(decoder, T5GemmaModuleConfig), f"{type(decoder)} is not supported." + + # Decouple encoder and decoder config in any case + encoder = T5GemmaModuleConfig(**encoder.to_dict()) + decoder = T5GemmaModuleConfig(**decoder.to_dict()) + + encoder.is_decoder = False + encoder.dropout_rate = dropout_rate + encoder.attention_dropout = attention_dropout + self.encoder = encoder + + decoder.is_decoder = True + decoder.use_cache = True + decoder.dropout_rate = dropout_rate + decoder.attention_dropout = attention_dropout + decoder.cross_attention_hidden_size = encoder.hidden_size + self.decoder = decoder + + for special_token_key in ["bos_token_id", "pad_token_id", "eos_token_id"]: + if special_token_key not in kwargs: + kwargs[special_token_key] = getattr(decoder, special_token_key) + + super().__init__(**kwargs) + + self.is_encoder_decoder = is_encoder_decoder + self.use_cache = kwargs.get("use_cache", decoder.use_cache) + self.initializer_range = kwargs.get("initializer_range", decoder.initializer_range) + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.classifier_dropout_rate = classifier_dropout_rate + self.tie_word_embeddings = tie_word_embeddings + + def __setattr__(self, key, value): + shared_attr_with_submodules = [ + "output_hidden_states", + "output_attentions", + "_attn_implementation", + "dropout_rate", + "attention_dropout", + ] + + if key in shared_attr_with_submodules: + setattr(self.encoder, key, value) + setattr(self.decoder, key, value) + 
super().__setattr__(key, value) + + def get_text_config(self, decoder=False) -> "PretrainedConfig": + # Always return self, regardless of the decoder option. + del decoder + return self + + +__all__ = ["T5GemmaConfig", "T5GemmaModuleConfig"] diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py new file mode 100644 index 0000000000..7f3ce0927a --- /dev/null +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -0,0 +1,1506 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/t5gemma/modular_t5gemma.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_t5gemma.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Callable, Optional, Union + +import torch +import torch.nn as nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, can_return_tuple, logging +from .configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig + + +logger = logging.get_logger(__name__) + + +class T5GemmaRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst T5Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class T5GemmaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + 
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, x): + hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x) + hidden_states = self.dropout(hidden_states) + down_proj = self.down_proj(hidden_states) + return down_proj + + +class T5GemmaRotaryEmbedding(nn.Module): + def __init__(self, config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis.

    Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep): the input goes from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    With `n_rep == 1` the input tensor itself is returned.
    """
    # Unpack first so a non-4D input is rejected even on the `n_rep == 1` fast path.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a singleton axis after the head axis, broadcast it to n_rep, then fold it into the heads.
        expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class T5GemmaSelfAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: T5GemmaModuleConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = config.query_pre_attn_scalar**-0.5 + self.attention_dropout = self.config.attention_dropout + # Requied by flash attention: encoder selfattention is non-causal + self.is_causal = config.is_decoder + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + 
class T5GemmaCrossAttention(nn.Module):
    """Cross-attention: queries come from `hidden_states`, keys/values from `encoder_hidden_states`.

    Queries are projected from `config.hidden_size`; keys and values are projected from
    `config.cross_attention_hidden_size`. No rotary position embedding is applied in this module.
    """

    def __init__(self, config: T5GemmaModuleConfig, layer_idx: int):
        """Build the q/k/v/o projections for layer `layer_idx`.

        Raises:
            ValueError: if `config.cross_attention_hidden_size` is None.
        """
        super().__init__()

        # Validate before building any projection: the key/value projections take this value as their
        # input width, so constructing them first would fail with an unrelated error and hide this message.
        if config.cross_attention_hidden_size is None:
            raise ValueError("Cross-attention needs cross_attention_hidden_size to be specified.")

        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.attention_dropout = self.config.attention_dropout

        # Required by flash attention
        self.is_causal = False

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )

        self.k_proj = nn.Linear(
            config.cross_attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.cross_attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.attn_logit_softcapping = self.config.attn_logit_softcapping

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        encoder_hidden_states: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend from `hidden_states` over `encoder_hidden_states`.

        Args:
            hidden_states: query-side input; its last dimension is projected by `q_proj`.
            attention_mask: mask handed unchanged to the attention implementation.
            encoder_hidden_states: key/value-side input; required.
            past_key_value: optional cache object exposing `is_updated` and `cross_attention_cache`.
                When this layer is already flagged in `is_updated`, the cached keys/values are reused
                and `encoder_hidden_states` is not projected again.
            **kwargs: forwarded to the attention implementation.

        Returns:
            `(attn_output, attn_weights)` where `attn_output` has gone through `o_proj`.

        Raises:
            ValueError: if `encoder_hidden_states` is None.
        """
        if encoder_hidden_states is None:
            raise ValueError("Encoder hidden state is required for cross attention.")

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        # [batch, q_len, -1, head_dim] => [batch, -1, q_len, head_dim]
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Keys/values must be computed when there is no cache, or when the cache has not yet been
        # filled for this layer.
        cross_cache = None
        needs_projection = True
        if past_key_value is not None:
            # after the first generated id, we can subsequently re-use all key/value_states from cache
            cross_cache = past_key_value.cross_attention_cache
            needs_projection = not past_key_value.is_updated.get(self.layer_idx)

        if needs_projection:
            encoder_input_shape = encoder_hidden_states.shape[:-1]
            encoder_hidden_shape = (*encoder_input_shape, -1, self.head_dim)
            # [batch, kv_len, -1, head_dim] => [batch, -1, kv_len, head_dim]
            key_states = self.k_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
            value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)

            if past_key_value is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                key_states, value_states = cross_cache.update(key_states, value_states, self.layer_idx)
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                past_key_value.is_updated[self.layer_idx] = True
        else:
            # cross-attention: reuse cached states
            key_states = cross_cache.key_cache[self.layer_idx]
            value_states = cross_cache.value_cache[self.layer_idx]

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                # Keep the eager implementation selected above and tell the user why.
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=None,
            softcap=self.attn_logit_softcapping,
            **kwargs,
        )

        # Merge the head axes back into one feature axis before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
position_embeddings=position_embeddings, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + # Remove all caches for encoders. + use_cache=False, + past_key_value=None, + **kwargs, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + # Mlp + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class T5GemmaDecoderLayer(T5GemmaEncoderLayer): + """Decoder sub-layer: an extra cross-attention layer.""" + + def __init__(self, config, layer_idx: int): + super().__init__(config, layer_idx) + # cross attention + self.cross_attn = T5GemmaCrossAttention(config=config, layer_idx=layer_idx) + self.pre_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[EncoderDecoderCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[ + torch.FloatTensor, + Optional[tuple[torch.FloatTensor, torch.FloatTensor]], + Optional[tuple[torch.FloatTensor, torch.FloatTensor]], + ]: + # Self Attention + residual = hidden_states + hidden_states = 
class T5GemmaClassificationHead(nn.Module):
    """Classification head: dropout followed by a single linear projection to `num_labels` scores."""

    def __init__(self, hidden_size: int, num_labels: int, classifier_dropout_rate: float = 0.0):
        super().__init__()
        self.dropout = nn.Dropout(p=classifier_dropout_rate)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `out_proj(dropout(hidden_states))`; only the last dimension changes size."""
        return self.out_proj(self.dropout(hidden_states))
@auto_docstring
class T5GemmaPreTrainedModel(PreTrainedModel):
    # Shared base for the T5Gemma models in this file: declares the capability flags read by the
    # framework, weight initialisation, and the helper that derives decoder inputs from labels.
    config_class = T5GemmaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Must name layer classes that actually exist in this file; an unknown name matches no module,
    # so nothing would be protected from being split across devices.
    _no_split_modules = ["T5GemmaEncoderLayer", "T5GemmaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """Initialise `module` in place according to its type.

        - `nn.Linear`: normal(0, `initializer_range`) weights, zero bias.
        - `nn.Embedding`: normal(0, `initializer_range`) weights, zeroed padding row when one is set.
        - `T5GemmaRMSNorm`: weight filled with 1.0.
        - `T5GemmaClassificationHead`: `out_proj` uses the std scaled by `out_features ** -0.5`, zero bias.
        - `T5GemmaLMHead`: same scaled std, but only when word embeddings are not tied.
        """
        # TODO: support initialization for encoders and decoders separately(?)
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, T5GemmaRMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, T5GemmaClassificationHead):
            scale = module.out_proj.weight.shape[0] ** -0.5
            module.out_proj.weight.data.normal_(mean=0.0, std=std * scale)
            if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
                module.out_proj.bias.data.zero_()
        elif isinstance(module, T5GemmaLMHead):
            if not self.config.tie_word_embeddings:
                scale = module.out_proj.weight.shape[0] ** -0.5
                module.out_proj.weight.data.normal_(mean=0.0, std=std * scale)

    def _shift_right(self, input_ids):
        """
        Shifts input_ids to the right, prepends the decoder_start_token_id, and handles
        pad_token_id replacement for labels that were -100.
        This is a common preparation step for decoder inputs in sequence-to-sequence models.

        Raises:
            ValueError: if `config.decoder.bos_token_id` or `config.decoder.pad_token_id` is None.
        """
        decoder_start_token_id = self.config.decoder.bos_token_id
        pad_token_id = self.config.decoder.pad_token_id

        # Validate both ids up front so a misconfigured model fails before any tensor is allocated.
        if decoder_start_token_id is None:
            raise ValueError("self.model.config.decoder.bos_token_id has to be defined. ")
        if pad_token_id is None:
            raise ValueError("self.model.config.decoder.pad_token_id has to be defined.")

        # shift inputs to the right
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        # Is this T5 specific?
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
def make_default_2d_attention_mask(
    token_ids: Optional[torch.LongTensor],
    hidden_states: torch.Tensor,
    pad_token_id: Optional[int],
) -> torch.Tensor:
    """Build a fallback 2D mask of dtype long on the device of `hidden_states`.

    With `token_ids`, positions equal to `pad_token_id` get 0 and all others 1.
    Without `token_ids`, the mask is all ones with the first two dimensions of `hidden_states`.

    Raises:
        ValueError: if `token_ids` is given but `pad_token_id` is None.
    """
    if token_ids is None:
        # No ids to inspect: treat every position as valid.
        batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1]
        return torch.ones((batch_size, seq_len), device=hidden_states.device, dtype=torch.long)

    if pad_token_id is None:
        raise ValueError("`pad_token_id` is required for padding information.")

    is_real_token = token_ids != pad_token_id
    return is_real_token.to(hidden_states.device, torch.long)
Unpack[FlashAttentionKwargs], + ) -> BaseModelOutput: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # Input embeddings + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # Cache position: only used for mask construction. + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + + # Postional ids. + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # Regular Attention mask. + if attention_mask is None: + attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id) + + # Attention masks + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": None, + } + # Create the masks + self_attn_mask_mapping = { + "full_attention": create_causal_mask( + **mask_kwargs, + or_mask_function=bidirectional_mask_function(attention_mask), + ), + "sliding_attention": create_sliding_window_causal_mask( + **mask_kwargs, + or_mask_function=sliding_window_bidirectional_mask_function(self.config.sliding_window), + and_mask_function=bidirectional_mask_function(attention_mask), + ), + } + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # normalized + # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + 
class T5GemmaDecoder(T5GemmaEncoder):
    """Decoder stack: reuses the encoder's setup but runs `T5GemmaDecoderLayer`s with caching and cross-attention."""

    def __init__(self, config):
        # The parent constructor builds everything, including a layer stack of its own;
        # that stack is replaced below with decoder layers.
        super().__init__(config)

        self.layers = nn.ModuleList(
            [T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        # Initialize weights and apply final processing
        # (run again because the layer stack was replaced after the parent's own `post_init`).
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[EncoderDecoderCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        """Run the decoder layers over the target sequence, attending to `encoder_hidden_states`.

        Exactly one of `input_ids` / `inputs_embeds` must be given, and `encoder_hidden_states` is
        mandatory. `attention_mask` and `encoder_attention_mask` may each be passed as a ready-made
        dict of masks, in which case it is used as-is instead of being built here.

        Returns:
            `BaseModelOutputWithPastAndCrossAttentions` holding the final hidden states, the cache
            (if any), and, when requested, per-layer hidden states, self-attentions and
            cross-attentions.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, or if
                `encoder_hidden_states` is None.
        """
        # Fall back to the config for any flag the caller left unset.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if encoder_hidden_states is None:
            raise ValueError("`encoder_hidden_states` must be given in decoder")

        # Input embeddings
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Caching: a fresh cache is only created outside training.
        if not self.training and use_cache and past_key_values is None:
            past_key_values = EncoderDecoderCache(
                self_attention_cache=DynamicCache(),
                cross_attention_cache=DynamicCache(),
            )

        # Cache positions: continue counting from whatever the cache already holds.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # Position ids.
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Regular Attention mask.
        # NOTE(review): the pad-based default is only built when there is no cache at all; a cache
        # created just above also skips it, leaving `attention_mask` as None — confirm this is intended.
        if attention_mask is None and past_key_values is None:
            attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id)

        # Attention masks: Self attention
        # (the walrus binds the caller's value; it is only replaced when it is not already a dict)
        if not isinstance(self_attn_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None,
            }
            # Create the masks
            self_attn_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        # Attention masks: Cross attention
        # Built from the encoder side; the bidirectional function is OR-ed in so that the mask
        # follows `encoder_attention_mask` rather than only the causal pattern.
        if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": encoder_hidden_states,
                "attention_mask": encoder_attention_mask,
                "cache_position": cache_position,
                "past_key_values": None,
            }
            cross_attn_mask_mapping = {
                "full_attention": create_causal_mask(
                    **mask_kwargs,
                    or_mask_function=bidirectional_mask_function(encoder_attention_mask),
                ),
            }

        # embed positions
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # normalized
        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        # transformer layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attns = () if output_attentions else None

        hidden_states = self.dropout(hidden_states)

        for layer_module in self.layers[: self.config.num_hidden_layers]:
            # Record the input of each layer; the final normalised output is appended after the loop.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # Arguments are positional and must follow the parameter order of `T5GemmaDecoderLayer.forward`.
            layer_outputs = layer_module(
                hidden_states,
                position_embeddings,
                self_attn_mask_mapping[layer_module.attention_type],
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
                encoder_hidden_states,
                cross_attn_mask_mapping["full_attention"],
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
                all_cross_attns += (layer_outputs[2],)

        hidden_states = self.norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attns,
        )
Use `T5GemmaEncoderModel` instead.") + + self.encoder = T5GemmaEncoder(config.encoder) + self.decoder = T5GemmaDecoder(config.decoder) + + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.encoder.set_input_embeddings(new_embeddings) + + @can_return_tuple + @auto_docstring + def forward( + self, + # encoder + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + # decoder + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + # others + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Seq2SeqModelOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + + **flash_attn_kwargs: flash attention related parameters. 
+ """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + **flash_attn_kwargs, + ) + + encoder_hidden_states = encoder_outputs.last_hidden_state + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **flash_attn_kwargs, + ) + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring +class T5GemmaEncoderModel(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig): + super().__init__(config) + + if config.is_encoder_decoder: + raise ValueError("T5GemmaEncoderModel only supports encoder-only model. 
@can_return_tuple
@auto_docstring
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutput:
    r"""
    **flash_attn_kwargs: flash attention related parameters.
    """
    # Pure pass-through: this model is only the encoder stack, so every
    # argument is forwarded unchanged and the encoder result is returned as-is.
    return self.encoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        **flash_attn_kwargs,
    )
+ if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.lm_head.out_proj, self.get_decoder().get_input_embeddings()) + + def get_encoder(self): + return self.model.encoder + + def get_decoder(self): + return self.model.decoder + + @can_return_tuple + @auto_docstring + def forward( + self, + # encoder + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + # decoder + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + # others + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, + ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ """ + if self.training and self.config._attn_implementation != "eager": + logger.warning_once( + "It is strongly recommended to train T5Gemma models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." + ) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + decoder_outputs: Seq2SeqModelOutput = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **loss_kwargs, + ) + + hidden_states = decoder_outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + decoder_config = self.get_decoder().config + if decoder_config.final_logit_softcapping is not None: + logits = logits / decoder_config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * decoder_config.final_logit_softcapping + + loss = None + if labels is not None: + # Input has right-shifted so we directly perform masked lm loss + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=decoder_outputs.past_key_values, + 
@can_return_tuple
@auto_docstring
def forward(
    self,
    # encoder
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    # decoder
    decoder_input_ids: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.Tensor] = None,
    decoder_position_ids: Optional[torch.LongTensor] = None,
    # others
    encoder_outputs: Optional[BaseModelOutput] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
) -> SequenceClassifierOutput:
    r"""
    decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
        Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
        config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)

    labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
        config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
        `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
    """
    # Overview: run the backbone (encoder-decoder or encoder-only, per
    # `config.is_encoder_decoder`), score every position with `self.score`,
    # then pool one position per example: the last non-pad token of
    # `input_ids` when it can be located, otherwise the final position.
    #
    # Raises NotImplementedError for `inputs_embeds` without `input_ids` in
    # encoder-decoder mode, and ValueError when decoder inputs cannot be
    # derived or when batch size > 1 without a configured pad token.
    if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None):
        raise NotImplementedError(
            f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode."
        )

    # Following T5, we automatically creates decoder_input_ids from input_ids if no decoder_input_ids are provided
    if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None):
        if input_ids is None:
            raise ValueError(
                "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                "passed, `input_ids` cannot be `None`. Please pass either "
                "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
            )
        decoder_input_ids = self._shift_right(input_ids)

    if self.config.is_encoder_decoder:
        outputs: Seq2SeqModelOutput = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        last_hidden_state = outputs.last_hidden_state
        hidden_states = outputs.decoder_hidden_states
        attentions = outputs.decoder_attentions
    else:
        outputs: BaseModelOutput = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        last_hidden_state = outputs.last_hidden_state
        hidden_states = outputs.hidden_states
        attentions = outputs.attentions

    logits = self.score(last_hidden_state)

    # `logits` is indexed below as [example, position], so its first two
    # dimensions give the batch size and the scored sequence length. Reading
    # them here, rather than from `input_ids` / `inputs_embeds` /
    # `decoder_input_ids`, keeps pooling working when any of those is None
    # (e.g. `decoder_inputs_embeds` or precomputed `encoder_outputs`).
    batch_size = logits.shape[0]
    last_position = logits.shape[1] - 1

    if self.config.pad_token_id is None and batch_size != 1:
        raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
    if self.config.pad_token_id is None:
        last_non_pad_token = -1
    elif input_ids is not None:
        # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
        non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
        token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
        last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)

        if self.config.is_encoder_decoder:
            last_non_pad_token += 1  # due to the right shift.
            # Clamp to the last scored position so the +1 never indexes past the end.
            last_non_pad_token = torch.clamp(last_non_pad_token, max=last_position)
    else:
        last_non_pad_token = -1
        logger.warning_once(
            f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
            "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
        )

    # One row of logits per example, taken at the pooled position.
    pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

    loss = None
    if labels is not None:
        loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

    return SequenceClassifierOutput(
        loss=loss,
        logits=pooled_logits,
        hidden_states=hidden_states,
        attentions=attentions,
    )
@can_return_tuple
@auto_docstring
def forward(
    self,
    # encoder
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    # decoder
    decoder_input_ids: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.Tensor] = None,
    decoder_position_ids: Optional[torch.LongTensor] = None,
    # others
    encoder_outputs: Optional[BaseModelOutput] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
) -> TokenClassifierOutput:
    r"""
    decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
        Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
        config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
    """
    # Overview: run the backbone (encoder-decoder or encoder-only, per
    # `config.is_encoder_decoder`) and score every position with `self.score`.
    # Unlike sequence classification there is no pooling: the returned logits
    # keep one row per position.
    if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None):
        raise NotImplementedError(
            f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode."
        )

    # With no decoder inputs supplied, derive them by right-shifting `input_ids`.
    if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None):
        if input_ids is None:
            raise ValueError(
                "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                "passed, `input_ids` cannot be `None`. Please pass either "
                "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
            )
        decoder_input_ids = self._shift_right(input_ids)

    if self.config.is_encoder_decoder:
        outputs: Seq2SeqModelOutput = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        last_hidden_state = outputs.last_hidden_state
        hidden_states = outputs.decoder_hidden_states
        attentions = outputs.decoder_attentions
    else:
        outputs: BaseModelOutput = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        last_hidden_state = outputs.last_hidden_state
        hidden_states = outputs.hidden_states
        attentions = outputs.attentions

    logits = self.score(last_hidden_state)

    loss = None
    if labels is not None:
        loss = self.loss_function(logits, labels, self.config)

    return TokenClassifierOutput(
        loss=loss,
        logits=logits,
        hidden_states=hidden_states,
        attentions=attentions,
    )
class T5GemmaModuleConfig(Gemma2Config):
    """Module config (encoder or decoder): the same as Gemma2Config."""

    # Adds nothing of its own: every keyword argument is handed straight to
    # the Gemma2Config constructor.
    # NOTE(review): `super_kwargs` looks like the spelling the modular-model
    # converter keys on to expand the parent signature - confirm before
    # renaming or inlining this constructor.
    def __init__(self, **super_kwargs):
        super().__init__(**super_kwargs)
[google/t5gemma-placeholder](https://huggingface.co/google/t5gemma-placeholder) + ```python + >>> from transformers import T5GemmaConfig, T5GemmaModel + >>> t5gemma_config = T5GemmaConfig.from_pretrained("google/t5gemma-placeholder") + >>> model = T5GemmaModel(t5gemma_config) + ``` + Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the + documentation from [PretrainedConfig] for more information. + Args: + encoder (`Union[T5GemmaModuleConfig, dict]`, optional, *optional*): + Configuration for the encoder. + decoder (`Union[T5GemmaModuleConfig, dict]`, optional, *optional*): + Configuration for the decoder. + is_encoder_decoder (bool, optional, *optional*, defaults to `True`): + Whether the model is used as an encoder/decoder or not. + dropout_rate (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers (following T5). + classifier_dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout ratio for classifier (following T5). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for attention. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether tie input and output embeddings. + kwargs (additional keyword arguments, optional, *optional*): + Will be passed to the PretrainedConfig base class. 
def __init__(
    self,
    encoder: Optional[Union[T5GemmaModuleConfig, dict[Any, Any]]] = None,
    decoder: Optional[Union[T5GemmaModuleConfig, dict[Any, Any]]] = None,
    is_encoder_decoder: bool = True,
    dropout_rate: float = 0.0,
    classifier_dropout_rate: float = 0.0,
    attention_dropout: float = 0.0,
    tie_word_embeddings: bool = True,
    **kwargs,
):
    """Build the joint config from optional encoder and decoder sub-configs.

    Each sub-config may be a `T5GemmaModuleConfig`, a dict of its constructor
    arguments, or `None`. A missing encoder gets default settings; a missing
    decoder mirrors the encoder.

    Raises:
        TypeError: if `encoder` or `decoder` is neither `None`, a dict, nor a
            `T5GemmaModuleConfig`.
    """
    # Encoder.
    if isinstance(encoder, dict):
        # From preset configuration
        encoder = T5GemmaModuleConfig(**encoder)
    elif encoder is None:
        # From scratch
        encoder = T5GemmaModuleConfig()
    elif not isinstance(encoder, T5GemmaModuleConfig):
        # Explicit raise rather than `assert`: asserts are stripped under
        # `python -O`, which would let an unsupported type slip through.
        raise TypeError(f"{type(encoder)} is not supported.")

    # Decoder.
    if isinstance(decoder, dict):
        # From preset configuration
        decoder = T5GemmaModuleConfig(**decoder)
    elif decoder is None:
        # From scratch
        decoder = encoder
    elif not isinstance(decoder, T5GemmaModuleConfig):
        raise TypeError(f"{type(decoder)} is not supported.")

    # Decouple encoder and decoder config in any case: rebuilding from dicts
    # means neither sub-config aliases the other or the caller's object.
    encoder = T5GemmaModuleConfig(**encoder.to_dict())
    decoder = T5GemmaModuleConfig(**decoder.to_dict())

    encoder.is_decoder = False
    encoder.dropout_rate = dropout_rate
    encoder.attention_dropout = attention_dropout
    self.encoder = encoder

    decoder.is_decoder = True
    decoder.use_cache = True
    decoder.dropout_rate = dropout_rate
    decoder.attention_dropout = attention_dropout
    # Cross-attention keys/values are projected from encoder states.
    decoder.cross_attention_hidden_size = encoder.hidden_size
    self.decoder = decoder

    # Special-token ids default to the decoder's unless the caller overrides them.
    for special_token_key in ["bos_token_id", "pad_token_id", "eos_token_id"]:
        if special_token_key not in kwargs:
            kwargs[special_token_key] = getattr(decoder, special_token_key)

    super().__init__(**kwargs)

    self.is_encoder_decoder = is_encoder_decoder
    self.use_cache = kwargs.get("use_cache", decoder.use_cache)
    self.initializer_range = kwargs.get("initializer_range", decoder.initializer_range)
    self.dropout_rate = dropout_rate
    self.attention_dropout = attention_dropout
    self.classifier_dropout_rate = classifier_dropout_rate
    self.tie_word_embeddings = tie_word_embeddings
class T5GemmaMLP(Gemma2MLP):
    """Gemma2 MLP with dropout applied before the down projection."""

    def __init__(self, config):
        super().__init__(config)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x):
        # Activated gate branch multiplied element-wise with the up branch.
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        # Dropout acts on the gated activations, not on the projected output.
        return self.down_proj(self.dropout(gated))
+ if encoder_hidden_states is None: + raise ValueError("Encoder hidden state is required for cross attention.") + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + # [batch, q_len, -1, head_dim] => [batch, -1, q_len, head_dim] + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + + # conditions for calculating key and value states + if ( + # no cache + past_key_value is None + # cross-attention but not cached yet + or not is_updated + ): + encoder_input_shape = encoder_hidden_states.shape[:-1] + encoder_hidden_shape = (*encoder_input_shape, -1, self.head_dim) + # [batch, kv_len, -1, head_dim] => [batch, -1, kv_len, head_dim] + key_states = self.k_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) + value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) + + # update cache + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + past_key_value.is_updated[self.layer_idx] = True + # cross-attention: reuse cached states + else: + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not 
def bidirectional_mask_function(attention_mask: Optional[torch.Tensor]) -> Callable:
    """
    Build a mask predicate for bidirectional (non-causal) attention.

    The returned callable takes `(batch_idx, head_idx, q_idx, kv_idx)` and
    depends only on the batch and key/value indices: a key position is
    attendable when `attention_mask[batch_idx, kv_idx]` is truthy. Without an
    attention mask every position is attendable.
    """

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        if attention_mask is not None:
            # attention_mask: [batch_size, kv_len]
            return attention_mask[batch_idx, kv_idx].to(torch.bool)
        # No mask supplied: every attention position is valid.
        return torch.ones((), dtype=torch.bool)

    return inner_mask
def forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    output_attentions: Optional[bool] = False,
    **kwargs,
) -> tuple[
    torch.FloatTensor,
    Optional[tuple[torch.FloatTensor, torch.FloatTensor]],
]:
    """Apply the self-attention and MLP sub-blocks, each with a residual connection.

    Each sub-block is wrapped as `x + dropout(post_norm(f(pre_norm(x))))`.
    Returns `(hidden_states,)`, with the self-attention weights appended when
    `output_attentions` is truthy.
    """
    # Self-attention sub-block.
    attn_input = self.pre_self_attn_layernorm(hidden_states)
    attn_output, self_attn_weights = self.self_attn(
        hidden_states=attn_input,
        position_embeddings=position_embeddings,
        attention_mask=attention_mask,
        position_ids=position_ids,
        output_attentions=output_attentions,
        # Encoders never use a key/value cache.
        use_cache=False,
        past_key_value=None,
        **kwargs,
    )
    attn_output = self.post_self_attn_layernorm(attn_output)
    hidden_states = hidden_states + self.dropout(attn_output)

    # Feed-forward sub-block.
    mlp_input = self.pre_feedforward_layernorm(hidden_states)
    mlp_output = self.post_feedforward_layernorm(self.mlp(mlp_input))
    hidden_states = hidden_states + self.dropout(mlp_output)

    if output_attentions:
        return (hidden_states, self_attn_weights)
    return (hidden_states,)
position_ids=position_ids, + past_key_value=past_key_value.self_attention_cache if past_key_value is not None else None, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = self.post_self_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + # Cross Attention + residual = hidden_states + hidden_states = self.pre_cross_attn_layernorm(hidden_states) + hidden_states, cross_attn_weights = self.cross_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + hidden_states = self.post_cross_attn_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + # Mlp + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +class T5GemmaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, hidden_size: int, num_labels: int, classifier_dropout_rate: float = 0.0): + super().__init__() + self.dropout = nn.Dropout(p=classifier_dropout_rate) + self.out_proj = nn.Linear(hidden_size, num_labels) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class T5GemmaLMHead(nn.Module): + """Head for language modeling (generation) tasks.""" + + def __init__(self, hidden_size: int, vocab_size: int, bias: bool = False): + super().__init__() + self.out_proj = 
nn.Linear(hidden_size, vocab_size, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.out_proj(hidden_states) + return logits + + +@auto_docstring +class T5GemmaPreTrainedModel(Gemma2PreTrainedModel): + config_class = T5GemmaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["T5GemmaBlock"] + + def _init_weights(self, module): + # TODO: support intialization for encoders and decoders separately(?) + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, T5GemmaRMSNorm): + module.weight.data.fill_(1.0) + elif isinstance(module, T5GemmaClassificationHead): + scale = module.out_proj.weight.shape[0] ** -0.5 + module.out_proj.weight.data.normal_(mean=0.0, std=std * scale) + if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: + module.out_proj.bias.data.zero_() + elif isinstance(module, T5GemmaLMHead): + if not self.config.tie_word_embeddings: + scale = module.out_proj.weight.shape[0] ** -0.5 + module.out_proj.weight.data.normal_(mean=0.0, std=std * scale) + + def _shift_right(self, input_ids): + """ + Shifts input_ids to the right, prepends the decoder_start_token_id, and handles + pad_token_id replacement for labels that were -100. + This is a common preparation step for decoder inputs in sequence-to-sequence models. + """ + decoder_start_token_id = self.config.decoder.bos_token_id + pad_token_id = self.config.decoder.pad_token_id + + if decoder_start_token_id is None: + raise ValueError("self.model.config.decoder.bos_token_id has to be defined. 
") + + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.decoder.pad_token_id has to be defined.") + + # Is this T5 specific? + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def make_default_2d_attention_mask( + token_ids: Optional[torch.LongTensor], + hidden_states: torch.Tensor, + pad_token_id: Optional[int], +) -> torch.Tensor: + """Construct the default attention mask.""" + if token_ids is not None: + if pad_token_id is None: + raise ValueError("`pad_token_id` is required for padding information.") + attention_mask = (token_ids != pad_token_id).to(hidden_states.device, torch.long) + else: + attention_mask = torch.ones( + (hidden_states.shape[0], hidden_states.shape[1]), device=hidden_states.device, dtype=torch.long + ) + return attention_mask + + +class T5GemmaEncoder(T5GemmaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = T5GemmaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + self.layers = nn.ModuleList( + [T5GemmaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + def forward( + self, + input_ids: 
Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutput: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # Input embeddings + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # Cache position: only used for mask construction. + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + + # Postional ids. + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # Regular Attention mask. 
+ if attention_mask is None: + attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id) + + # Attention masks + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": None, + } + # Create the masks + self_attn_mask_mapping = { + "full_attention": create_causal_mask( + **mask_kwargs, + or_mask_function=bidirectional_mask_function(attention_mask), + ), + "sliding_attention": create_sliding_window_causal_mask( + **mask_kwargs, + or_mask_function=sliding_window_bidirectional_mask_function(self.config.sliding_window), + and_mask_function=bidirectional_mask_function(attention_mask), + ), + } + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # normalized + # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + + # transformer layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + hidden_states = self.dropout(hidden_states) + + for layer_module in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + position_embeddings, + self_attn_mask_mapping[layer_module.attention_type], + position_ids, + output_attentions, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + 
hidden_states = self.dropout(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class T5GemmaDecoder(T5GemmaEncoder): + def __init__(self, config): + super().__init__(config) + + self.layers = nn.ModuleList( + [T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPastAndCrossAttentions: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if encoder_hidden_states is None: + raise ValueError("`encoder_hidden_states` must be given in decoder") + + # Input embeddings + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # Caching + if not self.training and use_cache and past_key_values is None: + past_key_values = EncoderDecoderCache( + self_attention_cache=DynamicCache(), + cross_attention_cache=DynamicCache(), + ) + + # Cache positions. + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # Position ids. + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # Regular Attention mask. + if attention_mask is None and past_key_values is None: + attention_mask = make_default_2d_attention_mask(input_ids, inputs_embeds, self.config.pad_token_id) + + # Attention masks: Self attention + if not isinstance(self_attn_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None, + } + # Create the masks + self_attn_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), + } + + # Attention masks: Cross attention + if not isinstance(cross_attn_mask_mapping := encoder_attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": encoder_hidden_states, + "attention_mask": encoder_attention_mask, + "cache_position": cache_position, + "past_key_values": None, + } + cross_attn_mask_mapping = { + "full_attention": create_causal_mask( + **mask_kwargs, + 
or_mask_function=bidirectional_mask_function(encoder_attention_mask), + ), + } + + # embed positions + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # normalized + # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) + hidden_states = hidden_states * normalizer + + # transformer layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attns = () if output_attentions else None + + hidden_states = self.dropout(hidden_states) + + for layer_module in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + position_embeddings, + self_attn_mask_mapping[layer_module.attention_type], + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + encoder_hidden_states, + cross_attn_mask_mapping["full_attention"], + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + all_cross_attns += (layer_outputs[2],) + + hidden_states = self.norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attns, + ) + + +@auto_docstring +class T5GemmaModel(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig): + super().__init__(config) + + if not config.is_encoder_decoder: + raise ValueError("T5GemmaModel 
only support encoder-decoder modeling. Use `T5GemmaEncoderModel` instead.") + + self.encoder = T5GemmaEncoder(config.encoder) + self.decoder = T5GemmaDecoder(config.decoder) + + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.encoder.set_input_embeddings(new_embeddings) + + @can_return_tuple + @auto_docstring + def forward( + self, + # encoder + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + # decoder + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + # others + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Seq2SeqModelOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + + **flash_attn_kwargs: flash attention related parameters. 
+ """ + use_cache = use_cache if use_cache is not None else self.config.use_cache + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + **flash_attn_kwargs, + ) + + encoder_hidden_states = encoder_outputs.last_hidden_state + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **flash_attn_kwargs, + ) + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring +class T5GemmaEncoderModel(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig): + super().__init__(config) + + if config.is_encoder_decoder: + raise ValueError("T5GemmaEncoderModel only supports encoder-only model. 
Use `T5GemmaModel` instead.") + + self.encoder = T5GemmaEncoder(config.encoder) + self.post_init() + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.encoder.set_input_embeddings(new_embeddings) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutput: + r""" + **flash_attn_kwargs: flash attention related parameters. + """ + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + **flash_attn_kwargs, + ) + return encoder_outputs + + +class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["model.decoder.embed_tokens.weight", "lm_head.out_proj.weight"] + _tp_plan = {"lm_head.out_proj": "colwise_rep"} + _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])} + + def __init__(self, config: T5GemmaConfig): + config.is_encoder_decoder = True + super().__init__(config) + + self.model = T5GemmaModel(config) + self.vocab_size = config.decoder.vocab_size + self.lm_head = T5GemmaLMHead(config.decoder.hidden_size, self.vocab_size) + self.loss_type = "ForMaskedLMLoss" + + self.post_init() + + def set_output_embeddings(self, new_embeddings): + self.lm_head.out_proj = new_embeddings + + def get_output_embeddings(self): + return self.lm_head.out_proj + + def _tie_weights(self): + # Decoder input and output embeddings are tied. 
+ if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.lm_head.out_proj, self.get_decoder().get_input_embeddings()) + + def get_encoder(self): + return self.model.encoder + + def get_decoder(self): + return self.model.decoder + + @can_return_tuple + @auto_docstring + def forward( + self, + # encoder + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + # decoder + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + # others + encoder_outputs: Optional[BaseModelOutput] = None, + past_key_values: Optional[EncoderDecoderCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, + ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ """ + if self.training and self.config._attn_implementation != "eager": + logger.warning_once( + "It is strongly recommended to train T5Gemma models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." + ) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + decoder_outputs: Seq2SeqModelOutput = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **loss_kwargs, + ) + + hidden_states = decoder_outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + decoder_config = self.get_decoder().config + if decoder_config.final_logit_softcapping is not None: + logits = logits / decoder_config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * decoder_config.final_logit_softcapping + + loss = None + if labels is not None: + # Input has right-shifted so we directly perform masked lm loss + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + return Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=decoder_outputs.past_key_values, + 
decoder_hidden_states=decoder_outputs.decoder_hidden_states, + decoder_attentions=decoder_outputs.decoder_attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=decoder_outputs.encoder_last_hidden_state, + encoder_hidden_states=decoder_outputs.encoder_hidden_states, + encoder_attentions=decoder_outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + +@auto_docstring +class T5GemmaForSequenceClassification(T5GemmaPreTrainedModel): + def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None): + """ + is_encoder_decoder (`Optional`, *optional*): + Whether use encoder_decoder for sequence classification. When set to False, only encoder is used. + """ + if is_encoder_decoder is not None: + config.is_encoder_decoder = is_encoder_decoder + super().__init__(config) + self.num_labels = config.num_labels + + if config.is_encoder_decoder: + self.model = T5GemmaModel(config) + else: + self.model = T5GemmaEncoderModel(config) + + hidden_size = config.encoder.hidden_size + if config.is_encoder_decoder: + hidden_size = config.decoder.hidden_size + + classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1) + self.score = T5GemmaClassificationHead(hidden_size, self.num_labels, classifier_dropout) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring + def forward( + self, + # encoder + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + # decoder + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + decoder_position_ids: Optional[torch.LongTensor] = None, + # others + encoder_outputs: 
Optional[BaseModelOutput] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutput: + r""" + decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, + config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None): + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode." + ) + + # Following T5, we automatically creates decoder_input_ids from input_ids if no decoder_input_ids are provided + if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None): + if input_ids is None: + raise ValueError( + "If no `decoder_input_ids` or `decoder_inputs_embeds` are " + "passed, `input_ids` cannot be `None`. Please pass either " + "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 
+ ) + decoder_input_ids = self._shift_right(input_ids) + + if self.config.is_encoder_decoder: + outputs: Seq2SeqModelOutput = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=False, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.decoder_hidden_states + attentions = outputs.decoder_attentions + else: + outputs: BaseModelOutput = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + last_hidden_state = outputs.last_hidden_state + hidden_states = outputs.hidden_states + attentions = outputs.attentions + + logits = self.score(last_hidden_state) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + + if self.config.is_encoder_decoder: + last_non_pad_token += 1 # due to the right shift. 
@auto_docstring
class T5GemmaForTokenClassification(T5GemmaPreTrainedModel):
    def __init__(self, config: T5GemmaConfig, is_encoder_decoder: Optional[bool] = None):
        """
        is_encoder_decoder (`Optional`, *optional*):
            Whether to use the encoder-decoder model for token classification. When set to `False`, only the encoder
            is used. When `None`, the value already stored in `config.is_encoder_decoder` is used.
        """
        # NOTE: an explicit override is written onto the caller's `config` object itself (no copy is made), so the
        # new value stays visible to every other holder of that config.
        if is_encoder_decoder is not None:
            config.is_encoder_decoder = is_encoder_decoder
        super().__init__(config)
        self.num_labels = config.num_labels

        # The backbone and the width of the states fed to the classification head go together: decoder states in
        # encoder-decoder mode, encoder states in encoder-only mode.
        if config.is_encoder_decoder:
            self.model = T5GemmaModel(config)
            hidden_size = config.decoder.hidden_size
        else:
            self.model = T5GemmaEncoderModel(config)
            hidden_size = config.encoder.hidden_size

        # Falls back to a dropout rate of 0.1 when the config does not define `classifier_dropout_rate`.
        classifier_dropout = getattr(config, "classifier_dropout_rate", 0.1)
        self.score = T5GemmaClassificationHead(hidden_size, self.num_labels, classifier_dropout)

        self.post_init()

    def get_input_embeddings(self):
        # Delegates to the wrapped backbone.
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        # Delegates to the wrapped backbone.
        self.model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        # encoder
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        # decoder
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        decoder_position_ids: Optional[torch.LongTensor] = None,
        # others
        encoder_outputs: Optional[BaseModelOutput] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> TokenClassifierOutput:
        r"""
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, decoder_sequence_length)`, *optional*):
            Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0,
            config.decoder.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`.
        """

        # In encoder-decoder mode the decoder inputs may have to be derived from `input_ids` below, so an
        # embeddings-only encoder input is rejected up front.
        if self.config.is_encoder_decoder and (input_ids is None and inputs_embeds is not None):
            raise NotImplementedError(
                f"Passing input embeddings is currently not supported for {self.__class__.__name__} in encoder-decoder mode."
            )

        # Without any explicit decoder input, feed the decoder the right-shifted encoder input.
        if self.config.is_encoder_decoder and (decoder_input_ids is None and decoder_inputs_embeds is None):
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )
            decoder_input_ids = self._shift_right(input_ids)

        if self.config.is_encoder_decoder:
            # Caching is switched off: the whole sequence is classified in a single pass.
            outputs: Seq2SeqModelOutput = self.model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                decoder_position_ids=decoder_position_ids,
                encoder_outputs=encoder_outputs,
                inputs_embeds=inputs_embeds,
                decoder_inputs_embeds=decoder_inputs_embeds,
                use_cache=False,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            # The classified states (and the reported extras) are the decoder's.
            last_hidden_state = outputs.last_hidden_state
            hidden_states = outputs.decoder_hidden_states
            attentions = outputs.decoder_attentions
        else:
            # Encoder-only mode: every decoder-side argument is ignored.
            outputs: BaseModelOutput = self.model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            last_hidden_state = outputs.last_hidden_state
            hidden_states = outputs.hidden_states
            attentions = outputs.attentions

        # One score vector per token position.
        logits = self.score(last_hidden_state)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )


__all__ = [
    "T5GemmaConfig",
    "T5GemmaModuleConfig",
    "T5GemmaForConditionalGeneration",
    "T5GemmaModel",
    "T5GemmaEncoderModel",
    "T5GemmaPreTrainedModel",  # noqa: F822
    "T5GemmaForSequenceClassification",
    "T5GemmaForTokenClassification",
]
class T5GemmaModelTester:
    """Builds small T5Gemma configs and random inputs, and hosts the `create_and_check_*` helpers.

    Every helper takes the tuple returned by `prepare_config_and_inputs` (unpacked) and reports through the
    `unittest.TestCase` stored in `self.parent`.
    """

    config_class = T5GemmaConfig
    module_config_class = T5GemmaModuleConfig

    # The model classes are only bound when torch is available.
    if is_torch_available():
        model_class = T5GemmaModel
        for_causal_lm_class = T5GemmaForConditionalGeneration
        for_sequence_class = T5GemmaForSequenceClassification
        for_token_class = T5GemmaForTokenClassification

    def __init__(
        self,
        parent,
        batch_size=13,
        is_training=True,
        use_attention_mask=True,
        use_labels=True,
        vocab_size=99,
        # decoder-specific
        seq_length=7,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        intermediate_size=37,
        # encoder-specific
        encoder_seq_length=7,
        encoder_hidden_size=32,
        encoder_num_hidden_layers=2,
        encoder_num_attention_heads=4,
        encoder_num_key_value_heads=2,
        encoder_intermediate_size=37,
        # common
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
        # special ids
        eos_token_id=1,
        pad_token_id=0,
        bos_token_id=2,
    ):
        """Store the test hyper-parameters; un-prefixed size arguments describe the decoder."""
        self.parent = parent
        self.batch_size = batch_size
        self.is_training = is_training
        self.use_attention_mask = use_attention_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        # decoder
        self.seq_length = seq_length
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        # encoder
        self.encoder_seq_length = encoder_seq_length
        self.encoder_hidden_size = encoder_hidden_size
        self.encoder_num_hidden_layers = encoder_num_hidden_layers
        self.encoder_num_attention_heads = encoder_num_attention_heads
        self.encoder_num_key_value_heads = encoder_num_key_value_heads
        self.encoder_intermediate_size = encoder_intermediate_size
        # common
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope
        self.head_dim = self.hidden_size // self.num_attention_heads
        # assume encoder and decoder have the same head dimension.
        assert self.head_dim == self.encoder_hidden_size // self.encoder_num_attention_heads
        # special ids
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        # assume the number of attention heads are the same across encoder and decoder
        # only used for generation testing purpose.
        assert self.num_attention_heads == self.encoder_num_attention_heads

    def get_encoder_config(self):
        """Return the module config for the encoder (`is_decoder=False`, `encoder_*` sizes)."""
        return self.module_config_class(
            vocab_size=self.vocab_size,
            hidden_size=self.encoder_hidden_size,
            num_hidden_layers=self.encoder_num_hidden_layers,
            num_attention_heads=self.encoder_num_attention_heads,
            num_key_value_heads=self.encoder_num_key_value_heads,
            intermediate_size=self.encoder_intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            head_dim=self.head_dim,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
        )

    def get_decoder_config(self):
        """Return the module config for the decoder (`is_decoder=True`, cross-attending to the encoder width)."""
        return self.module_config_class(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            intermediate_size=self.intermediate_size,
            cross_attention_hidden_size=self.encoder_hidden_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=True,
            initializer_range=self.initializer_range,
            head_dim=self.head_dim,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
        )

    def get_config(self, is_encoder_decoder=True):
        """Return the full config built from the encoder and decoder module configs."""
        return self.config_class(
            encoder=self.get_encoder_config(),
            decoder=self.get_decoder_config(),
            is_encoder_decoder=is_encoder_decoder,
            # Used for generation test.
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
        )

    def prepare_config_and_inputs(self):
        """Return `(config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels)`.

        Encoder tensors have length `encoder_seq_length`, decoder tensors and labels have length `seq_length`.
        The masks are `None` when `use_attention_mask` is false, and the labels are `None` when `use_labels` is
        false.
        """
        input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        # Remove BOS symbols from inputs.
        # NOTE(review): the replacement id 42 is hard-coded, which presumably requires `vocab_size > 42`
        # (the default is 99) - confirm before shrinking the vocabulary.
        input_ids = torch.where(input_ids == self.bos_token_id, 42, input_ids)
        decoder_input_ids = torch.where(decoder_input_ids == self.bos_token_id, 42, decoder_input_ids)

        attention_mask = None
        decoder_attention_mask = None
        if self.use_attention_mask:
            # Random 0/1 masks.
            attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
            decoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        lm_labels = None
        if self.use_labels:
            lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        config = self.get_config()

        return (
            config,
            input_ids,
            decoder_input_ids,
            attention_mask,
            decoder_attention_mask,
            lm_labels,
        )

    # Copied from tests.models.t5.test_modeling_t5.T5ModelTester.prepare_config_and_inputs_for_common
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            decoder_input_ids,
            attention_mask,
            decoder_attention_mask,
            lm_labels,
        ) = config_and_inputs

        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
        }
        return config, inputs_dict

    def create_and_check_model(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        """Run the base model once and check output shapes and the size of the returned cache."""
        model = self.model_class(config=config).to(torch_device).eval()

        result = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        decoder_output = result.last_hidden_state
        decoder_past = result.past_key_values
        encoder_output = result.encoder_last_hidden_state

        self.parent.assertEqual(
            encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.encoder_hidden_size)
        )
        self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertIsNotNone(decoder_past)
        # One cache entry per decoder layer, for self-attention and for cross-attention.
        self.parent.assertEqual(len(decoder_past.self_attention_cache), config.decoder.num_hidden_layers)
        self.parent.assertEqual(len(decoder_past.cross_attention_cache.key_cache), config.decoder.num_hidden_layers)

    def check_prepare_lm_labels_via_shift_left(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        """Check that `_shift_right` prepends the decoder BOS id and shifts the labels by one position."""
        model = self.model_class(config=config).to(torch_device).eval()

        # _shift_right should be called on labels
        shifted_labels = model._shift_right(lm_labels)

        # first token should be the BOS id of the decoder config
        self.parent.assertTrue(torch.all(shifted_labels[:, 0] == config.decoder.bos_token_id))

        # the rest should be the labels shifted by one, with -100 replaced by pad_token_id
        labels_without_ignore_index = lm_labels.masked_fill(lm_labels == -100, config.decoder.pad_token_id)
        self.parent.assertTrue(torch.all(shifted_labels[:, 1:] == labels_without_ignore_index[:, :-1]))

    def create_and_check_with_lm_head(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        """Run the conditional-generation model with labels and check the logits shape and a scalar loss."""
        model = self.for_causal_lm_class(config=config).to(torch_device).eval()
        outputs = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=lm_labels,
        )
        self.parent.assertEqual(len(outputs), 4)
        self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, self.vocab_size))
        self.parent.assertEqual(outputs["loss"].size(), ())

    def create_and_check_with_sequence_classification_head(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        """Run the sequence-classification model (encoder ids reused as decoder ids) and check its outputs."""
        labels = torch.tensor([1] * self.batch_size, dtype=torch.long, device=torch_device)
        model = self.for_sequence_class(config=config).to(torch_device).eval()
        outputs = model(
            input_ids=input_ids,
            decoder_input_ids=input_ids,
            labels=labels,
        )
        self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, config.num_labels))
        self.parent.assertEqual(outputs["loss"].size(), ())

    def create_and_check_encoderonly_for_sequence_classification_head(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
        is_encoder_decoder,
    ):
        """Same check as above, with the head built with an explicit `is_encoder_decoder` override."""
        labels = torch.tensor([1] * self.batch_size, dtype=torch.long, device=torch_device)
        model = self.for_sequence_class(config=config, is_encoder_decoder=is_encoder_decoder)
        model = model.to(torch_device).eval()
        outputs = model(
            input_ids=input_ids,
            decoder_input_ids=input_ids,
            labels=labels,
        )

        self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, config.num_labels))
        self.parent.assertEqual(outputs["loss"].size(), ())

    def create_and_check_encoderonly_for_token_classification_head(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
        is_encoder_decoder,
    ):
        """Run the token-classification head with an explicit `is_encoder_decoder` override and check its outputs."""
        # The model only ever sees `input_ids` (as encoder input and, in encoder-decoder mode, as decoder input
        # too), so labels and expected logits are sized from it rather than from `self.seq_length`; the two only
        # coincide while `encoder_seq_length == seq_length`.
        num_tokens = input_ids.shape[-1]
        labels = torch.tensor([1] * num_tokens * self.batch_size, dtype=torch.long, device=torch_device)
        model = self.for_token_class(config=config, is_encoder_decoder=is_encoder_decoder)
        model = model.to(torch_device).eval()
        outputs = model(
            input_ids=input_ids,
            decoder_input_ids=input_ids,
            labels=labels,
        )

        self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, num_tokens, config.num_labels))
        self.parent.assertEqual(outputs["loss"].size(), ())

    def create_and_check_decoder_model_past(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        """Check that decoding one extra token from the cache matches a full forward pass without cache."""
        model = self.model_class(config=config).get_decoder().to(torch_device).eval()
        # Constant stand-in for the encoder output.
        encoder_hidden_states = torch.ones(
            (self.batch_size, self.encoder_seq_length, self.encoder_hidden_size),
            dtype=torch.float32,
            device=torch_device,
        )

        # first forward pass
        outputs = model(input_ids, encoder_hidden_states=encoder_hidden_states, use_cache=True)
        outputs_use_cache_conf = model(input_ids, encoder_hidden_states=encoder_hidden_states)
        outputs_no_past = model(input_ids, encoder_hidden_states=encoder_hidden_states, use_cache=False)

        # Caching is on by default and contributes exactly one extra output field.
        self.parent.assertEqual(len(outputs), len(outputs_use_cache_conf))
        self.parent.assertEqual(len(outputs), len(outputs_no_past) + 1)

        _, past_key_values = outputs.to_tuple()

        # create a hypothetical next token
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # append it to input_ids to get next_input_ids
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)

        output_from_no_past = model(next_input_ids, encoder_hidden_states=encoder_hidden_states)["last_hidden_state"]
        output_from_past = model(
            next_tokens, encoder_hidden_states=encoder_hidden_states, past_key_values=past_key_values
        )["last_hidden_state"]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_decoder_model_attention_mask_past(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        """Check cached decoding against a full pass when the second half of the input is masked out.

        The mask used here is built locally; the `attention_mask` argument is not read.
        """
        model = self.model_class(config=config).get_decoder().to(torch_device).eval()
        # Constant stand-in for the encoder output.
        encoder_hidden_states = torch.ones(
            (self.batch_size, self.encoder_seq_length, self.encoder_hidden_size),
            dtype=torch.float32,
            device=torch_device,
        )

        # create attention mask
        attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)

        half_seq_length = input_ids.shape[-1] // 2
        attn_mask[:, half_seq_length:] = 0

        # first forward pass
        _, past_key_values = model(
            input_ids, encoder_hidden_states=encoder_hidden_states, attention_mask=attn_mask, use_cache=True
        ).to_tuple()

        # create a hypothetical next token
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # change a random masked slice from input_ids (on a copy, so the caller's tensor is left untouched)
        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
        random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
        input_ids = input_ids.clone()
        input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens

        # append the new token to input_ids and extend attn_mask accordingly
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        attn_mask = torch.cat(
            [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
            dim=1,
        )

        # get two different outputs
        output_from_no_past = model(
            next_input_ids, encoder_hidden_states=encoder_hidden_states, attention_mask=attn_mask
        )["last_hidden_state"]
        output_from_past = model(
            next_tokens,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            attention_mask=attn_mask,
        )["last_hidden_state"]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        """Check cached decoding against a full pass when three new tokens are appended at once."""
        model = self.model_class(config=config).get_decoder().to(torch_device).eval()
        # Constant stand-in for the encoder output.
        encoder_hidden_states = torch.ones(
            (self.batch_size, self.encoder_seq_length, self.encoder_hidden_size),
            dtype=torch.float32,
            device=torch_device,
        )

        # first forward pass
        outputs = model(
            input_ids, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, use_cache=True
        )

        _, past_key_values = outputs.to_tuple()

        # create multiple hypothetical next tokens, with a random mask for them
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append them to input_ids and attention_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids, encoder_hidden_states=encoder_hidden_states, attention_mask=next_attention_mask
        )["last_hidden_state"]
        output_from_past = model(
            next_tokens,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=next_attention_mask,
            past_key_values=past_key_values,
        )["last_hidden_state"]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertEqual(output_from_past_slice.shape[1], next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_generate_with_past_key_values(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        """Check that sampled generation is identical with and without the cache under the same seed."""
        model = self.for_causal_lm_class(config=config).to(torch_device).eval()
        torch.manual_seed(0)
        output_without_past_cache = model.generate(
            input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False
        )
        # Re-seed so both runs draw the same samples.
        torch.manual_seed(0)
        output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
        self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))

    def create_and_check_model_fp16_forward(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        """Run the base model in half precision and check that the output contains no NaN."""
        model = self.model_class(config=config).to(torch_device).half().eval()
        output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"]
        self.parent.assertFalse(torch.isnan(output).any().item())
config_class=T5GemmaConfig, + # For faking the testing. + hidden_size=37, + vocab_size=self.model_tester.vocab_size, + num_attention_heads=self.model_tester.num_attention_heads, + num_hidden_layers=self.model_tester.num_hidden_layers, + ) + + # Copied from tests.models.t5.test_modeling_t5.T5ModelTest.is_pipeline_test_to_skip + def is_pipeline_test_to_skip( + self, + pipeline_test_case_name, + config_class, + model_architecture, + tokenizer_name, + image_processor_name, + feature_extractor_name, + processor_name, + ): + if tokenizer_name is None: + return True + if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"): + return True + + return False + + # Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_config + def test_config(self): + self.config_tester.run_common_tests() + + # Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_shift_right + def test_shift_right(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs) + + # Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_model + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + # Based on tests.models.t5.test_modeling_t5.T5ModelTest.test_inputs_embeds + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in (T5GemmaModel, T5GemmaForConditionalGeneration): + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + 
inputs.pop("decoder_input_ids", None) + + wte = model.get_input_embeddings() + if not self.is_encoder_decoder: + inputs["inputs_embeds"] = wte(input_ids) + else: + inputs["inputs_embeds"] = wte(encoder_input_ids) + inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) + + with torch.no_grad(): + model(**inputs)[0] + + # Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_config_and_model_silu_gated + def test_config_and_model_silu_gated(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config = config_and_inputs[0] + config.feed_forward_proj = "gated-silu" + self.model_tester.create_and_check_model(*config_and_inputs) + + # Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_with_lm_head + def test_with_lm_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_lm_head(*config_and_inputs) + + # Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_with_sequence_classification_head + def test_with_sequence_classification_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs) + + @parameterized.expand([(True,), (False,)]) + def test_encoderonly_sequence_classification_head(self, is_encoder_decoder): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_encoderonly_for_sequence_classification_head( + *config_and_inputs, is_encoder_decoder + ) + + @parameterized.expand([(True,), (False,)]) + def test_encoderonly_token_classification_head(self, is_encoder_decoder): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_encoderonly_for_token_classification_head( + *config_and_inputs, is_encoder_decoder + ) + + # Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_decoder_model_past + def test_decoder_model_past(self): + 
config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + # Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_decoder_model_past_with_attn_mask + def test_decoder_model_past_with_attn_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + # Based on tests.models.t5.test_modeling_t5.T5ModelTest.test_decoder_model_past_with_3d_attn_mask + def test_decoder_model_past_with_3d_attn_mask(self): + ( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = self.model_tester.prepare_config_and_inputs() + + attention_mask = ids_tensor( + [self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length], + vocab_size=2, + ) + decoder_attention_mask = ids_tensor( + [self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.seq_length], + vocab_size=2, + ) + + self.model_tester.create_and_check_decoder_model_attention_mask_past( + config, + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) + + # Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_decoder_model_past_with_large_inputs + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + # Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_generate_with_past_key_values + def test_generate_with_past_key_values(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_generate_with_past_key_values(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Can't do half precision") + # Copied from 
tests.models.t5.test_modeling_t5.T5ModelTest.test_model_fp16_forward + def test_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) + + # Based on tests.models.gemma.test_modeling_gemma.GemmaModelTest.test_Gemma_sequence_classification_model with Gemma -> T5Gemma (Add is_encoder_decoder option) + def test_T5Gemma_sequence_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + + for is_encoder_decoder in [True, False]: + model = ( + self.model_tester.for_sequence_class(config, is_encoder_decoder=is_encoder_decoder) + .to(torch_device) + .eval() + ) + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + # Based on tests.models.gemma.test_modeling_gemma.GemmaModelTest.test_Gemma_sequence_classification_model_for_single_label with Gemma -> T5Gemma (Add is_encoder_decoder option) + def test_T5Gemma_sequence_classification_model_for_single_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "single_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + + for is_encoder_decoder in [True, False]: + model = ( + self.model_tester.for_sequence_class(config, is_encoder_decoder=is_encoder_decoder) + .to(torch_device) + .eval() + ) + result = model(input_ids, attention_mask=attention_mask, 
labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + # Based on tests.models.gemma.test_modeling_gemma.GemmaModelTest.test_Gemma_sequence_classification_model_for_multi_label with Gemma -> T5Gemma (Add is_encoder_decoder option) + def test_T5Gemma_sequence_classification_model_for_multi_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "multi_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor( + [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size + ).to(torch.float) + + for is_encoder_decoder in [True, False]: + model = ( + self.model_tester.for_sequence_class(config, is_encoder_decoder=is_encoder_decoder) + .to(torch_device) + .eval() + ) + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + # Based on tests.models.gemma.test_modeling_gemma.GemmaModelTest.test_Gemma_token_classification_model with Gemma -> T5Gemma (Add is_encoder_decoder option) + def test_T5Gemma_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + + for is_encoder_decoder in [True, False]: + model = ( + self.model_tester.for_token_class(config, is_encoder_decoder=is_encoder_decoder) + .to(torch_device) + .eval() + ) + + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, 
self.model_tester.seq_length, self.model_tester.num_labels), + ) + + # Based on tests.models.gemma.test_modeling_gemma.GemmaModelTest.test_sdpa_equivalence + # Add decoder_input_ids and adjust hidden states. + @require_torch_sdpa + @require_torch_accelerator + def test_sdpa_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_sdpa: + self.skipTest(reason="Model does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(torch_device) + dummy_input = inputs_dict[model_class.main_input_name].to(torch_device) + decoder_dummy_input = torch.ones_like(dummy_input) + + model.config._attn_implementation = "sdpa" + states_sdpa = model(dummy_input, decoder_input_ids=decoder_dummy_input, output_hidden_states=True) + + model.config._attn_implementation = "eager" + states_eager = model(dummy_input, decoder_input_ids=decoder_dummy_input, output_hidden_states=True) + + if hasattr(states_sdpa, "decoder_hidden_states"): + states_sdpa = states_sdpa.decoder_hidden_states[-1] + states_eager = states_eager.decoder_hidden_states[-1] + else: + states_sdpa = states_sdpa.hidden_states[-1] + states_eager = states_eager.hidden_states[-1] + + torch.testing.assert_close(states_sdpa, states_eager, atol=1e-5, rtol=1e-5) + + @unittest.skip("T5Gemma eager/FA2 attention outputs are expected to be different") + def test_flash_attn_2_equivalence(self): + pass + + # Based on tests.test_modeling_common.ModelTesterMixin.test_attention_outputs + # Skip token classification + def test_attention_outputs(self): + if not self.has_attentions: + self.skipTest(reason="Model does not output attentions") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + # force eager attention to support output attentions + config._attn_implementation = "eager" + + seq_len = getattr(self.model_tester, "seq_length", None) + decoder_seq_length = getattr(self.model_tester, 
"decoder_seq_length", seq_len) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + chunk_length = getattr(self.model_tester, "chunk_length", None) + if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): + encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes + + for model_class in self.all_model_classes: + # Skip token and sequence classification. + if model_class in [self.model_tester.for_token_class, self.model_tester.for_sequence_class]: + continue + + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config._attn_implementation = "eager" + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + if chunk_length is not None: + self.assertListEqual( + list(attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, 
encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) + + if self.is_encoder_decoder: + correct_outlen = 5 + + # loss is at first position + if "labels" in inputs_dict: + correct_outlen += 1 # loss is added to beginning + if "past_key_values" in outputs: + correct_outlen += 1 # past_key_values have been returned + + self.assertEqual(out_len, correct_outlen) + + # decoder attentions + decoder_attentions = outputs.decoder_attentions + self.assertIsInstance(decoder_attentions, (list, tuple)) + self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length], + ) + + # cross attentions + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + decoder_seq_length, + encoder_key_length, + ], + ) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + if chunk_length is not None: + self.assertListEqual( + list(self_attentions[0].shape[-4:]), + 
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + # Based on tests.generation.test_utils.GenerationTesterMixin.test_past_key_values_format + # Adjust encoder attention number for cross-attention caching and update attention head dimension + @pytest.mark.generate + def test_past_key_values_format(self, custom_all_cache_shapes=None): + """ + Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test, or pass the + expected cache shapes. + Having a standard KV cache format is important for a consistent API (and for advanced generation methods). + """ + for model_class in self.all_generative_model_classes: + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + # 1. If it doesn't support cache, skip the test + if not hasattr(config.get_text_config(), "use_cache"): + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + + model = model_class(config).to(torch_device) + model = model.eval() + if "use_cache" not in inputs: + inputs["use_cache"] = True + outputs = model(**inputs) + + if "past_key_values" not in outputs: + self.skipTest(reason="This model doesn't return `past_key_values`") + + # 2. 
retrieve the KV cache and compute its default expected shapes (if no custom shapes are provided) + past_kv = outputs["past_key_values"] + is_legacy_cache = not isinstance(past_kv, Cache) + + text_config = config.get_text_config().decoder + num_decoder_layers = text_config.num_hidden_layers + + if custom_all_cache_shapes is None: + num_query_attention_heads = getattr( + text_config, "decoder_attention_heads", text_config.num_attention_heads + ) + per_head_embed_dim = text_config.head_dim + num_key_value_heads = ( + text_config.num_key_value_heads + if getattr(text_config, "num_key_value_heads", None) is not None + else num_query_attention_heads + ) + if config.is_encoder_decoder: + encoder_num_attention_heads = num_key_value_heads + encoder_per_head_embed_dim = per_head_embed_dim + batch_size, seq_length = inputs["decoder_input_ids"].shape[:2] + # The sequence length for the encoder K V depends on the model. Since it is not manipulated in + # autoregressive generation, we're keeping the test general and not checking the 3rd dim + default_cross_attention_shape = ( + batch_size, + encoder_num_attention_heads, + encoder_per_head_embed_dim, + ) + default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) + all_cache_shapes = [ + [ + default_self_attention_shape, + default_self_attention_shape, + default_cross_attention_shape, + default_cross_attention_shape, + ] + for _ in range(num_decoder_layers) + ] + else: + batch_size, seq_length = inputs["input_ids"].shape[:2] + default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim) + all_cache_shapes = [ + [default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers) + ] + + else: + all_cache_shapes = custom_all_cache_shapes + + # 3. Check cache shapes + # 3.1. 
Encoder-Decoder checks + if config.is_encoder_decoder: + num_cache_decoder_layers = ( + len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache) + ) + self.assertEqual(num_cache_decoder_layers, num_decoder_layers) + + for i in range(num_decoder_layers): + if is_legacy_cache: + self.assertEqual(len(past_kv[0]), 4) # legacy check: confirm number of elements in tuple + + # Self attention + self_attention_layer_key_cache = ( + past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i] + ) + self_attention_layer_value_cache = ( + past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i] + ) + self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + + # Cross attention (ignore 3rd dim, see default shape preparation) + cross_attention_layer_key_cache = ( + past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i] + ) + cross_attention_layer_value_cache = ( + past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i] + ) + cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :] + cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :] + self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2]) + self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3]) + + # 3.2. 
Decoder-only checks + else: + num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache) + self.assertEqual(num_cache_decoder_layers, num_decoder_layers) + + for i in range(num_decoder_layers): + if is_legacy_cache: + self.assertEqual(len(past_kv[0]), 2) # legacy check: confirm number of elements in tuple + + # Self attention + self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i] + self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i] + self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0]) + self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1]) + + @unittest.skip("Mismatch issue doesn't exist in T5Gemma.") + def test_load_with_mismatched_shapes(self): + pass + + # Based on tests.generation.test_utils.GenerationTesterMixin.test_generate_continue_from_past_key_values + # Updated decoder_attention_mask to consider the appended bos token + @pytest.mark.generate + def test_generate_continue_from_past_key_values(self): + # Tests that we can continue generating from past key values, returned from a previous `generate` call + for model_class in self.all_generative_model_classes: + if model_class == self.model_tester.for_token_class: + continue + if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]): + self.skipTest(reason="Won't fix: old model with unique inputs/caches/other") + if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]): + self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility") + + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + if not hasattr(config.get_text_config(), "use_cache"): + self.skipTest(reason=f"{model_class.__name__} doesn't support caching") + + # Let's make it always: + # 1. use cache (for obvious reasons) + # 2. 
generate to max length (which can be achieved by setting the eos token to an invalid value), which + # would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the + # continuation would force it to generate beyond an EOS token) + # 3. ignore `token_type_ids` for simplicity + # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is + # active by default on some models + # 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When + # we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents + # repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls + # with cache, what is considered a prompt is different in the two cases. + + if "token_type_ids" in inputs: + del inputs["token_type_ids"] + + model = model_class(config).to(torch_device) + model.eval() + + # If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format) + outputs = model(**inputs) + if "past_key_values" not in outputs: + self.skipTest(reason="This model doesn't return `past_key_values`") + + generate_kwargs = { + "pad_token_id": -1, + "eos_token_id": -1, + "forced_eos_token_id": None, + "encoder_no_repeat_ngram_size": 0, + "use_cache": True, + "do_sample": False, + "return_dict_in_generate": True, + "output_scores": True, + } + + # Traditional way of generating text, with `return_dict_in_generate` to return the past key values + outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4) + + # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the + # inputs may need to be tweaked across `generate` calls (like the attention mask). 
+ outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3) + + # Continue from the tokens generated above, preparing the inputs accordingly + inputs["past_key_values"] = outputs_cached.past_key_values + new_attention_len = outputs_cached.sequences.shape[-1] + + # It must be encoder-decoder models + self.assertTrue(config.is_encoder_decoder) + + inputs["decoder_input_ids"] = outputs_cached.sequences + if "decoder_attention_mask" in inputs: + decoder_attention_mask = inputs["decoder_attention_mask"] + + # Add BOS mask: the new sequence comes with a new BOS token, which is not included in the original inputs + padding_tensor = torch.ones_like(decoder_attention_mask[:, :1]) + decoder_attention_mask = torch.cat([padding_tensor, decoder_attention_mask], dim=1) + + inputs["decoder_attention_mask"] = torch.nn.functional.pad( + decoder_attention_mask, + (0, new_attention_len - decoder_attention_mask.shape[1]), + mode="constant", + value=1, + ) + + first_caches_scores = outputs_cached.scores + outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1) + full_cached_scores = first_caches_scores + outputs_cached.scores + outputs_cached.scores = full_cached_scores + + # The two sets of generated text and past kv should be equal to each other + self._check_similar_generate_outputs(outputs, outputs_cached) + for layer_idx in range(len(outputs_cached.past_key_values)): + for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])): + self.assertTrue( + torch.allclose( + outputs.past_key_values[layer_idx][kv_idx], + outputs_cached.past_key_values[layer_idx][kv_idx], + ) + ) + + # Based on tests.test_modeling_common.ModelTesterMixin.test_inputs_embeds_matches_input_ids + # Update encoder and decoder embeddings + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model_class = self.model_tester.model_class + + model = model_class(config) + 
model.to(torch_device) + model.eval() + + model_forward_args = inspect.signature(model.forward).parameters + if "inputs_embeds" not in model_forward_args: + self.skipTest(reason="This model doesn't use `inputs_embeds`") + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1 + + encoder_embedding = model.get_encoder().get_input_embeddings() + decoder_embedding = model.get_decoder().get_input_embeddings() + + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + encoder_input_ids[encoder_input_ids == pad_token_id] = max(0, pad_token_id + 1) + decoder_input_ids[decoder_input_ids == pad_token_id] = max(0, pad_token_id + 1) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + inputs_embeds = encoder_embedding(encoder_input_ids) + decoder_inputs_embeds = decoder_embedding(decoder_input_ids) + with torch.no_grad(): + out_ids = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, **inputs)[0] + out_embeds = model(inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **inputs)[0] + + torch.testing.assert_close(out_embeds, out_ids) + + # Based on tests.test_modeling_common.ModelTesterMixin.test_inputs_embeds_matches_input_ids + # Adjust token classiifcation + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + if model_class in [self.model_tester.for_token_class, self.model_tester.for_sequence_class]: + model = model_class(config, is_encoder_decoder=False) + else: + model = model_class(config) + + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, 
"expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + if hasattr(self.model_tester, "encoder_seq_length"): + seq_length = self.model_tester.encoder_seq_length + if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1: + seq_length = seq_length * self.model_tester.chunk_length + else: + seq_length = self.model_tester.seq_length + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + if config.is_encoder_decoder: + hidden_states = outputs.decoder_hidden_states + + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + seq_len = getattr(self.model_tester, "seq_length", None) + decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [decoder_seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + # Based on tests.models.t5.test_modeling_t5.T5ModelTest.test_custom_4d_attention_mask + # Excluding the final token from input_ids + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + 
decoder_input_ids=input_ids, + input_ids=input_ids[:, :-1], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_ids[:1, :-1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + + # Based on tests.test_modeling_common.ModelTesterMixin.test_flex_attention_with_grads + # Update hidden size for encoder and decoder + @require_torch_gpu + def test_flex_attention_with_grads(self): + for model_class in self.all_model_classes: + # TODO: raushan, fix for composite models after making VLMs support new attn API + if not model_class._supports_flex_attn or self._is_composite: + self.skipTest(reason="This model does not support flex attention") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config._attn_implementation = "flex_attention" + # Flex Attention cannot use dropout + config.encoder.attention_dropout = 0 + config.decoder.attention_dropout = 0 + + # Flex attention relies on triton on compilation + # However, triton cannot handle hidden dimensions of less than 16 + # --> forcing at least a hidden dim of 16 + config.encoder.hidden_size *= max( + 16 + // getattr( + config.encoder, "head_dim", config.encoder.hidden_size // config.encoder.num_attention_heads + ), + 1, + ) + config.decoder.hidden_size *= max( + 16 + // getattr( + config.decoder, "head_dim", config.decoder.hidden_size // config.decoder.num_attention_heads + ), + 1, + ) + config.decoder.cross_attention_hidden_size = config.encoder.hidden_size + 
+ config.decoder.head_dim = max(16, config.decoder.head_dim) + config.encoder.head_dim = max(16, config.encoder.head_dim) + + model = model_class(config).to(device=torch_device) + self.assertTrue(model.config._attn_implementation == "flex_attention") + + # Elaborate workaround for encoder-decoder models as some do not specify their main input + dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)} + if config.is_encoder_decoder: + dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device) + dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device) + + # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605) + _ = model(**dummy_inputs) + + @unittest.skip("EncoderDecoderCache can't be gathered because it is not iterable.") + def test_multi_gpu_data_parallel_forward(self): + pass + + +class T5GemmaEncoderOnlyModelTester: + config_class = T5GemmaConfig + module_config_class = T5GemmaModuleConfig + + if is_torch_available(): + model_class = T5GemmaEncoderModel + + def __init__( + self, + parent, + batch_size=13, + is_training=True, + use_attention_mask=True, + use_labels=True, + vocab_size=99, + seq_length=7, + # default to encoders + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=37, + # common + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + # special ids + eos_token_id=1, + pad_token_id=0, + bos_token_id=2, + ): + self.parent = parent + self.batch_size = batch_size + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + # encoder + self.seq_length = seq_length + 
self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_size = intermediate_size + # common + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.head_dim = self.hidden_size // self.num_attention_heads + # special ids + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + + def get_encoder_config(self): + return self.module_config_class( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + head_dim=self.head_dim, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + ) + + def get_config(self): + return self.config_class( + encoder=self.get_encoder_config(), + decoder=None, + is_encoder_decoder=False, + # Used for generation test. 
+ num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + ) + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + # Remove BOS symbols from inputs. + input_ids = torch.where(input_ids == self.bos_token_id, 42, input_ids) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + config = self.get_config() + + return ( + config, + input_ids, + attention_mask, + ) + + def create_and_check_model( + self, + config, + input_ids, + attention_mask, + ): + model = self.model_class(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids=input_ids, + attention_mask=attention_mask, + ) + result = model(input_ids=input_ids) + encoder_output = result.last_hidden_state + + self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_fp16_forward( + self, + config, + input_ids, + attention_mask, + ): + model = self.model_class(config=config).to(torch_device).half().eval() + output = model(input_ids, attention_mask=attention_mask)["last_hidden_state"] + self.parent.assertFalse(torch.isnan(output).any().item()) + + def create_and_check_with_token_classification_head( + self, + config, + input_ids, + attention_mask, + ): + labels = torch.tensor([1] * self.seq_length * self.batch_size, dtype=torch.long, device=torch_device) + model = T5GemmaForTokenClassification(config=config, is_encoder_decoder=False).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + labels=labels, + attention_mask=attention_mask, + ) + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.seq_length, config.num_labels)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def prepare_config_and_inputs_for_common(self): + 
config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + attention_mask, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class T5GemmaEncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (T5GemmaEncoderModel, T5GemmaForTokenClassification) if is_torch_available() else () + test_pruning = False + test_resize_embeddings = False + test_headmasking = False + _is_stateful = True + is_encoder_decoder = False + model_split_percents = [0.4, 0.5] + + def setUp(self): + self.model_tester = T5GemmaEncoderOnlyModelTester(self) + self.config_tester = ConfigTester( + self, + config_class=T5GemmaConfig, + # For faking the testing. + hidden_size=37, + vocab_size=self.model_tester.vocab_size, + num_attention_heads=self.model_tester.num_attention_heads, + num_hidden_layers=self.model_tester.num_hidden_layers, + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Can't do half precision") + def test_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) + + def test_with_token_classification_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs) + + @unittest.skip("No loss in the output of T5GemmaEncoderModel") + def test_training(self): + pass + + @unittest.skip("No loss in the output of T5GemmaEncoderModel") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip("No loss in the output of T5GemmaEncoderModel") + def 
test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip("No loss in the output of T5GemmaEncoderModel") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + # Based on tests.test_modeling_common.ModelTesterMixin.test_flex_attention_with_grads + # Update hidden size for encoder + @require_torch_gpu + def test_flex_attention_with_grads(self): + for model_class in self.all_model_classes: + # TODO: raushan, fix for composite models after making VLMs support new attn API + if not model_class._supports_flex_attn or self._is_composite: + self.skipTest(reason="This model does not support flex attention") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config._attn_implementation = "flex_attention" + # Flex Attention cannot use dropout + config.encoder.attention_dropout = 0 + + # Flex attention relies on triton on compilation + # However, triton cannot handle hidden dimensions of less than 16 + # --> forcing at least a hidden dim of 16 + config.encoder.hidden_size *= max( + 16 + // getattr( + config.encoder, "head_dim", config.encoder.hidden_size // config.encoder.num_attention_heads + ), + 1, + ) + config.encoder.head_dim = max(16, config.encoder.head_dim) + + model = model_class(config).to(device=torch_device) + self.assertTrue(model.config._attn_implementation == "flex_attention") + + # Elaborate workaround for encoder-decoder models as some do not specify their main input + dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)} + + # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605) + _ = model(**dummy_inputs) + + +# Based on tests.models.t5.test_modeling_t5.TestAsymmetricT5 +# Adapted for T5Gemma +@require_torch +class TestAsymmetricT5Gemma(unittest.TestCase): + def build_model_and_check_forward_pass(self, **kwargs): + tester = T5GemmaModelTester(self, **kwargs) + config, *inputs = 
tester.prepare_config_and_inputs() + ( + input_ids, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = inputs + model = T5GemmaForConditionalGeneration(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=lm_labels, + ) + # outputs = model(*inputs) + assert len(outputs) == 4 + assert outputs["logits"].size() == (tester.batch_size, tester.seq_length, tester.vocab_size) + assert outputs["loss"].size() == () + return model.model + + def test_small_decoder(self): + model = self.build_model_and_check_forward_pass(num_hidden_layers=1, encoder_num_hidden_layers=2) + assert len(model.encoder.layers) == 2 + assert len(model.decoder.layers) == 1 + + def test_defaulting_to_symmetry(self): + model = self.build_model_and_check_forward_pass(num_hidden_layers=2, encoder_num_hidden_layers=2) + assert len(model.decoder.layers) == len(model.encoder.layers) == 2