From 965e98dc54c4853b34b3a49810f24856a9fcdba8 Mon Sep 17 00:00:00 2001 From: Aritra Roy Gosthipaty Date: Thu, 23 May 2024 22:18:49 +0530 Subject: [PATCH] [Port] TensorFlow implementation of Mistral (#29708) * chore: initial commit * chore: adding imports and inits * chore: adding the causal and classification code * chore: adding names to the layers * chore: using single self attn layer * chore: built the model and layers * chore: start with testing * chore: docstring change, transpose fix * fix: rotary embedding * chore: adding cache implementation * remove unused torch * chore: fixing the indexing issue * make fix-copies * Use modeling_tf_utils.keras * make fixup * chore: fixing tests * chore: adding past key value logic * chore: adding multi label classfication test * fix: switching on the built parameters in the layers * fixing repo consistency * ruff formats * style changes * fix: tf and pt equivalence * removing returns from docstrings * fix docstrings * fix docstrings * removing todos * fix copies * fix docstring * fix docstring * chore: using easier rotate_half * adding integration tests * chore: addressing review related to rotary embedding layer * review changes * [run-slow] mistral * skip: test save load after resize token embedding * style --------- Co-authored-by: Matt --- docs/source/en/index.md | 2 +- docs/source/en/model_doc/mistral.md | 17 +- src/transformers/__init__.py | 9 + .../models/auto/modeling_tf_auto.py | 3 + src/transformers/models/mistral/__init__.py | 34 +- .../models/mistral/modeling_tf_mistral.py | 1055 +++++++++++++++++ src/transformers/utils/dummy_tf_objects.py | 28 + .../mistral/test_modeling_tf_mistral.py | 367 ++++++ 8 files changed, 1512 insertions(+), 3 deletions(-) create mode 100644 src/transformers/models/mistral/modeling_tf_mistral.py create mode 100644 tests/models/mistral/test_modeling_tf_mistral.py diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 9a8c2ebbe3..72237d1383 100644 --- 
a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -200,7 +200,7 @@ Flax), PyTorch, and/or TensorFlow. | [Megatron-BERT](model_doc/megatron-bert) | ✅ | ❌ | ❌ | | [Megatron-GPT2](model_doc/megatron_gpt2) | ✅ | ✅ | ✅ | | [MGP-STR](model_doc/mgp-str) | ✅ | ❌ | ❌ | -| [Mistral](model_doc/mistral) | ✅ | ❌ | ✅ | +| [Mistral](model_doc/mistral) | ✅ | ✅ | ✅ | | [Mixtral](model_doc/mixtral) | ✅ | ❌ | ❌ | | [mLUKE](model_doc/mluke) | ✅ | ❌ | ❌ | | [MMS](model_doc/mms) | ✅ | ✅ | ✅ | diff --git a/docs/source/en/model_doc/mistral.md b/docs/source/en/model_doc/mistral.md index d4bc761060..17ce15b2b8 100644 --- a/docs/source/en/model_doc/mistral.md +++ b/docs/source/en/model_doc/mistral.md @@ -216,4 +216,19 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h ## FlaxMistralForCausalLM [[autodoc]] FlaxMistralForCausalLM - - __call__ \ No newline at end of file + - __call__ + +## TFMistralModel + +[[autodoc]] TFMistralModel + - call + +## TFMistralForCausalLM + +[[autodoc]] TFMistralForCausalLM + - call + +## TFMistralForSequenceClassification + +[[autodoc]] TFMistralForSequenceClassification + - call \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 8da7a8b3e3..fc8f6b1a9c 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3974,6 +3974,9 @@ else: _import_structure["models.mbart"].extend( ["TFMBartForConditionalGeneration", "TFMBartModel", "TFMBartPreTrainedModel"] ) + _import_structure["models.mistral"].extend( + ["TFMistralForCausalLM", "TFMistralForSequenceClassification", "TFMistralModel", "TFMistralPreTrainedModel"] + ) _import_structure["models.mobilebert"].extend( [ "TFMobileBertForMaskedLM", @@ -8067,6 +8070,12 @@ if TYPE_CHECKING: TFMBartModel, TFMBartPreTrainedModel, ) + from .models.mistral import ( + TFMistralForCausalLM, + TFMistralForSequenceClassification, + TFMistralModel, + TFMistralPreTrainedModel, + ) from .models.mobilebert import ( 
TFMobileBertForMaskedLM, TFMobileBertForMultipleChoice, diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index 2004495756..906fe411d0 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -65,6 +65,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict( ("lxmert", "TFLxmertModel"), ("marian", "TFMarianModel"), ("mbart", "TFMBartModel"), + ("mistral", "TFMistralModel"), ("mobilebert", "TFMobileBertModel"), ("mobilevit", "TFMobileViTModel"), ("mpnet", "TFMPNetModel"), @@ -178,6 +179,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("gpt-sw3", "TFGPT2LMHeadModel"), ("gpt2", "TFGPT2LMHeadModel"), ("gptj", "TFGPTJForCausalLM"), + ("mistral", "TFMistralForCausalLM"), ("openai-gpt", "TFOpenAIGPTLMHeadModel"), ("opt", "TFOPTForCausalLM"), ("rembert", "TFRemBertForCausalLM"), @@ -320,6 +322,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("layoutlm", "TFLayoutLMForSequenceClassification"), ("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"), ("longformer", "TFLongformerForSequenceClassification"), + ("mistral", "TFMistralForSequenceClassification"), ("mobilebert", "TFMobileBertForSequenceClassification"), ("mpnet", "TFMPNetForSequenceClassification"), ("openai-gpt", "TFOpenAIGPTForSequenceClassification"), diff --git a/src/transformers/models/mistral/__init__.py b/src/transformers/models/mistral/__init__.py index abf1e32a4b..93e551e193 100644 --- a/src/transformers/models/mistral/__init__.py +++ b/src/transformers/models/mistral/__init__.py @@ -13,7 +13,13 @@ # limitations under the License. 
from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_torch_available +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_flax_available, + is_tf_available, + is_torch_available, +) _import_structure = { @@ -47,6 +53,19 @@ else: "FlaxMistralPreTrainedModel", ] +try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_tf_mistral"] = [ + "TFMistralModel", + "TFMistralForCausalLM", + "TFMistralForSequenceClassification", + "TFMistralPreTrainedModel", + ] + if TYPE_CHECKING: from .configuration_mistral import MistralConfig @@ -77,6 +96,19 @@ if TYPE_CHECKING: FlaxMistralPreTrainedModel, ) + try: + if not is_tf_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_tf_mistral import ( + TFMistralForCausalLM, + TFMistralForSequenceClassification, + TFMistralModel, + TFMistralPreTrainedModel, + ) + else: import sys diff --git a/src/transformers/models/mistral/modeling_tf_mistral.py b/src/transformers/models/mistral/modeling_tf_mistral.py new file mode 100644 index 0000000000..3215439802 --- /dev/null +++ b/src/transformers/models/mistral/modeling_tf_mistral.py @@ -0,0 +1,1055 @@ +# coding=utf-8 +# Copyright 2024 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def _make_causal_mask(input_ids_shape, dtype, past_key_values_length=0):
    """
    Build the additive causal (uni-directional) attention mask.

    Allowed positions hold `0` and masked (future) positions hold the most negative finite value of `dtype`, so the
    mask can be added to the raw attention scores before the softmax.

    Args:
        input_ids_shape: `(bsz, tgt_len)` pair. `bsz` may be `None` when the batch size is not known.
        dtype: Floating point dtype of the returned mask.
        past_key_values_length (`int`, *optional*, defaults to 0):
            Number of cached key/value positions. They are prepended as always-visible (zero) columns.

    Returns:
        `tf.Tensor` of shape `(bsz, 1, tgt_len, tgt_len + past_key_values_length)`. When `bsz` is `None` the leading
        dimension is `1`, which broadcasts against any batch size when the mask is added to another tensor.
    """
    bsz, tgt_len = input_ids_shape

    # Start from a fully masked square, then zero out the lower triangle and the diagonal so that query position i
    # can attend to key positions j <= i.
    mask = tf.fill((tgt_len, tgt_len), tf.dtypes.as_dtype(dtype).min)
    mask_cond = tf.range(tgt_len)
    # `zeros_like` keeps the replacement values in the same dtype as `mask` instead of relying on scalar promotion.
    mask = tf.where(mask_cond[:, None] >= mask_cond[None, :], tf.zeros_like(mask), mask)

    if past_key_values_length > 0:
        # Every query may attend to all cached positions.
        mask = tf.concat([tf.zeros((tgt_len, past_key_values_length), dtype=dtype), mask], axis=-1)

    # shape: (1, 1, tgt_len, tgt_len + past_key_values_length)
    mask = mask[None, None, :, :]

    if bsz is None:
        # The batch size is unknown, so it cannot be used as a tile/broadcast multiple (the previous
        # `tf.tile(mask, [None, 1, 1, 1])` could not work). Keep a broadcastable batch dimension of 1 instead.
        return mask

    return tf.broadcast_to(mask, (bsz, 1, tgt_len, tgt_len + past_key_values_length))
class TFMistralRMSNorm(keras.layers.Layer):
    """
    Root-mean-square normalization over the last axis with a learned per-feature scale.

    The input is divided by `sqrt(mean(x ** 2) + eps)`; there is no mean subtraction and no bias term.
    """

    def __init__(self, hidden_size, eps=1e-6, **kwargs):
        """
        TFMistralRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size (`int`): Size of the last (feature) axis; also the length of the `weight` vector.
            eps (`float`, *optional*, defaults to 1e-6): Value added to the variance before the square root.
        """
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.variance_epsilon = eps

    def build(self, input_shape=None):
        # Check the guard *before* creating the variable: otherwise a second `build()` call would register the
        # weight again.
        if self.built:
            return
        self.weight = self.add_weight(
            name="weight",
            shape=(self.hidden_size,),
            initializer="ones",
        )
        self.built = True

    def call(self, hidden_states):
        input_dtype = hidden_states.dtype
        # The statistics are computed in float32 and the result is cast back to the caller's dtype.
        hidden_states = tf.cast(hidden_states, tf.float32)
        variance = tf.reduce_mean(tf.square(hidden_states), axis=-1, keepdims=True)
        hidden_states = tf.divide(hidden_states, tf.sqrt(variance + self.variance_epsilon))
        return self.weight * tf.cast(hidden_states, input_dtype)
def rotate_half(x):
    """
    Split the last axis of `x` at its midpoint and return `concat([-second_part, first_part])` along that axis.
    """
    half = shape_list(x)[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return tf.concat([tf.negative(back), front], axis=-1)
class TFMistralMLP(keras.layers.Layer):
    """
    Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, all projections without bias.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Two parallel projections into the intermediate size, one projection back to the hidden size.
        self.gate_proj = keras.layers.Dense(self.intermediate_size, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(self.intermediate_size, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(self.hidden_size, use_bias=False, name="down_proj")
        self.act_fn = get_tf_activation(config.hidden_act)

    def call(self, x):
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # (attribute name, size of that projection's input features), built in the same order as declared above.
        projections = (
            ("gate_proj", self.hidden_size),
            ("up_proj", self.hidden_size),
            ("down_proj", self.intermediate_size),
        )
        for attr_name, in_features in projections:
            sublayer = getattr(self, attr_name, None)
            if sublayer is None:
                continue
            # Build under the sub-layer's own name scope so its variables are named after it.
            with tf.name_scope(sublayer.name):
                sublayer.build((in_features,))
class TFMistralAttention(keras.layers.Layer):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".

    Keys and values use `config.num_key_value_heads` heads and are repeated to match `config.num_attention_heads`
    query heads. Rotary position embeddings are applied to queries and keys. Masking is whatever additive
    `attention_mask` the caller passes in.
    """

    def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share one key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = keras.layers.Dense(self.num_heads * self.head_dim, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, name="v_proj")
        self.o_proj = keras.layers.Dense(self.hidden_size, use_bias=False, name="o_proj")

        self.rotary_emb = TFMistralRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
            name="rotary_emb",
        )
        self.dropout = keras.layers.Dropout(rate=self.attention_dropout)

    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
        """Reshape `(bsz, seq_len, num_heads * head_dim)` to `(bsz, num_heads, seq_len, head_dim)`."""
        tensor = tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim))
        tensor = tf.transpose(tensor, perm=(0, 2, 1, 3))
        return tensor

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        past_key_value: Optional[Tuple[tf.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        training=None,
        **kwargs,
    ) -> Tuple[tf.Tensor, Optional[tf.Tensor], Optional[Tuple[tf.Tensor]]]:
        """
        Args:
            hidden_states (`tf.Tensor`): Input of shape `(batch, seq_len, hidden_size)`.
            attention_mask (`tf.Tensor`, *optional*): Additive mask that is summed with the attention scores.
            position_ids (`tf.Tensor`, *optional*):
                Positions used to index the rotary tables. When omitted, positions continue after the cached length.
            past_key_value (`Tuple(tf.Tensor)`, *optional*): Cached `(key, value)` states, concatenated on axis 2.
            output_attentions (`bool`, *optional*): Whether to return the attention probabilities.
            use_cache (`bool`, *optional*): Whether to return the updated `(key, value)` cache.
            training (`bool`, *optional*): Forwarded to the attention dropout layer.

        Returns:
            `(attn_output, attn_weights, past_key_value)`. `attn_weights` is `None` unless `output_attentions` is set
            and `past_key_value` is `None` unless `use_cache` is set.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = shape_list(hidden_states)

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = tf.transpose(
            tf.reshape(query_states, (bsz, q_len, self.num_heads, self.head_dim)), perm=(0, 2, 1, 3)
        )
        key_states = tf.transpose(
            tf.reshape(key_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)), perm=(0, 2, 1, 3)
        )
        value_states = tf.transpose(
            tf.reshape(value_states, (bsz, q_len, self.num_key_value_heads, self.head_dim)), perm=(0, 2, 1, 3)
        )

        past_len = 0
        if past_key_value is not None:
            # Use `shape_list`, like the rest of this file, rather than the static `.shape`, whose entries can be
            # `None` while tracing a graph.
            past_len = shape_list(past_key_value[0])[-2]
        kv_seq_len = shape_list(key_states)[-2] + past_len

        if position_ids is None:
            # The parameter defaults to `None`, but `tf.gather` needs real indices: number the new tokens right after
            # the cached ones.
            position_ids = tf.expand_dims(tf.range(start=past_len, limit=kv_seq_len, dtype=tf.int64), 0)

        cos, sin = self.rotary_emb(
            x=value_states,
            seq_len=kv_seq_len,
        )
        query_states, key_states = apply_rotary_pos_emb(
            q=query_states,
            k=key_states,
            cos=cos,
            sin=sin,
            position_ids=position_ids,
        )

        if past_key_value is not None:
            # reuse cached k, v: prepend them along the sequence axis
            key_states = tf.concat([past_key_value[0], key_states], axis=2)
            value_states = tf.concat([past_key_value[1], value_states], axis=2)

        # The cache stores the un-repeated key/value heads.
        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = tf.matmul(query_states, key_states, transpose_b=True) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # softmax over the key axis, then make sure the probabilities share the query dtype
        attn_weights = stable_softmax(attn_weights, axis=-1)
        attn_weights = tf.cast(attn_weights, query_states.dtype)
        attn_weights = self.dropout(
            attn_weights,
            training=training,
        )
        attn_output = tf.matmul(attn_weights, value_states)

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
        attn_output = tf.transpose(attn_output, perm=(0, 2, 1, 3))
        attn_output = tf.reshape(attn_output, (bsz, q_len, self.hidden_size))

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Each projection is built under its own name scope with the size of its input features.
        if getattr(self, "q_proj", None) is not None:
            with tf.name_scope(self.q_proj.name):
                self.q_proj.build((self.hidden_size,))
        if getattr(self, "k_proj", None) is not None:
            with tf.name_scope(self.k_proj.name):
                self.k_proj.build((self.hidden_size,))
        if getattr(self, "v_proj", None) is not None:
            with tf.name_scope(self.v_proj.name):
                self.v_proj.build((self.hidden_size,))
        if getattr(self, "o_proj", None) is not None:
            with tf.name_scope(self.o_proj.name):
                self.o_proj.build((self.num_heads * self.head_dim,))
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(tf.Tensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attn", None) is not None: + with tf.name_scope(self.self_attn.name): + self.self_attn.build(None) + if getattr(self, "mlp", None) is not None: + with tf.name_scope(self.mlp.name): + self.mlp.build(None) + if getattr(self, "input_layernorm", None) is not None: + with tf.name_scope(self.input_layernorm.name): + self.input_layernorm.build(None) + if getattr(self, "post_attention_layernorm", None) is not None: + with tf.name_scope(self.post_attention_layernorm.name): + 
self.post_attention_layernorm.build(None) + + +@keras_serializable +class TFMistralMainLayer(keras.layers.Layer): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`] + + Args: + config: MistralConfig + """ + + config_class = MistralConfig + + def __init__(self, config: MistralConfig, **kwargs): + super().__init__(**kwargs) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + + # TF and PT Embedding check: https://colab.research.google.com/gist/ariG23498/2b9826818875c9c4968c79cb19f55f2c/scratchpad.ipynb + self.embed_tokens = keras.layers.Embedding( + input_dim=config.vocab_size, + output_dim=config.hidden_size, + name="embed_tokens", + ) + self.layers = [ + TFMistralDecoderLayer(config, layer_idx, name=f"layers.{layer_idx}") + for layer_idx in range(config.num_hidden_layers) + ] + self._attn_implementation = config._attn_implementation + self.norm = TFMistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="norm") + self.config = config + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + # if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + 
# NOTE: `call` and `build` directly below are methods of `TFMistralMainLayer`; the class header and its
# `__init__` / `_prepare_decoder_attention_mask` sit in the part of the patch that precedes this span.
@unpack_inputs
def call(
    self,
    input_ids: tf.Tensor = None,
    attention_mask: Optional[tf.Tensor] = None,
    position_ids: Optional[tf.Tensor] = None,
    past_key_values: Optional[List[tf.Tensor]] = None,
    inputs_embeds: Optional[tf.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, TFBaseModelOutputWithPast]:
    """
    Run the stack of decoder layers followed by the final RMS norm.

    Exactly one of `input_ids` and `inputs_embeds` must be given. When `position_ids` is omitted they are
    generated as a range that starts at the cached length, and when `attention_mask` is omitted every position
    (cached and new) is attended to.

    Returns:
        A `TFBaseModelOutputWithPast` when `return_dict` is truthy, otherwise a tuple holding the non-`None`
        entries of `(last_hidden_state, past_key_values, hidden_states, attentions)` in that order.

    Raises:
        ValueError: if both or neither of `input_ids` and `inputs_embeds` are provided.
    """
    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = shape_list(input_ids)
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = shape_list(inputs_embeds)
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    seq_length_with_past = seq_length
    past_key_values_length = 0

    if past_key_values is not None:
        # the cached length is read from axis 2 of the first cached tensor of the first layer
        past_key_values_length = shape_list(past_key_values[0][0])[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length

    if position_ids is None:
        position_ids = tf.range(
            start=past_key_values_length, limit=seq_length + past_key_values_length, dtype=tf.int64
        )
        position_ids = tf.reshape(tf.expand_dims(position_ids, 0), (-1, seq_length))
    else:
        position_ids = tf.cast(tf.reshape(position_ids, (-1, seq_length)), tf.int64)

    if inputs_embeds is None:
        check_embeddings_within_bounds(input_ids, self.config.vocab_size)
        inputs_embeds = self.embed_tokens(input_ids)

    if attention_mask is None:
        # the default mask spans cached + new tokens
        attention_mask = tf.ones((batch_size, seq_length_with_past), dtype=tf.bool)
    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )

    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        hidden_states = layer_outputs[0]

        if use_cache:
            # the cache entry comes after the attention weights when those are returned as well
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return TFBaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


def build(self, input_shape=None):
    """Build the embedding, the final norm and every decoder layer, each under its own name scope."""
    if self.built:
        return
    self.built = True
    if getattr(self, "embed_tokens", None) is not None:
        with tf.name_scope(self.embed_tokens.name):
            self.embed_tokens.build(None)
    if getattr(self, "norm", None) is not None:
        with tf.name_scope(self.norm.name):
            self.norm.build(None)
    if getattr(self, "layers", None) is not None:
        for layer in self.layers:
            with tf.name_scope(layer.name):
                layer.build(None)


MISTRAL_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.



    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, position_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_ids": input_ids, "position_ids": position_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!



    Parameters:
        config ([`MistralConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
    MISTRAL_START_DOCSTRING,
)
class TFMistralPreTrainedModel(TFPreTrainedModel):
    # Configuration class used to instantiate the models below.
    config_class = MistralConfig
    # Name of the attribute that holds the base `TFMistralMainLayer` in every model of this file.
    base_model_prefix = "model"


MISTRAL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(tf.Tensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
    MISTRAL_START_DOCSTRING,
)
class TFMistralModel(TFMistralPreTrainedModel):
    def __init__(self, config: MistralConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.model = TFMistralMainLayer(config, name="model")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
    def call(
        self,
        input_ids: tf.Tensor = None,
        attention_mask: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        past_key_values: Optional[List[tf.Tensor]] = None,
        inputs_embeds: Optional[tf.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TFBaseModelOutputWithPast]:
        # Thin wrapper: every argument is forwarded unchanged to the main layer.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return outputs

    def build(self, input_shape=None):
        """Build the wrapped main layer under its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)


class TFMistralForCausalLM(TFMistralPreTrainedModel, TFCausalLanguageModelingLoss):
    """Mistral decoder with a bias-free dense language-modeling head projecting to `config.vocab_size`."""

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.model = TFMistralMainLayer(config, name="model")
        self.vocab_size = config.vocab_size
        self.lm_head = keras.layers.Dense(
            config.vocab_size,
            use_bias=False,
            kernel_initializer=get_initializer(config.initializer_range),
            name="lm_head",
        )
        self.config = config

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @unpack_inputs
    @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def call(
        self,
        input_ids: tf.Tensor = None,
        attention_mask: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        past_key_values: Optional[List[tf.Tensor]] = None,
        inputs_embeds: Optional[tf.Tensor] = None,
        labels: Optional[tf.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TFCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the causal language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. The
                labels are shifted inside the model, so `labels` can simply be a copy of `input_ids`.
        """

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = tf.cast(logits, tf.float32)

        loss = None
        if labels is not None:
            # shift labels to the left and cut last logit token
            shifted_logits = logits[:, :-1]
            labels = labels[:, 1:]
            loss = self.hf_compute_loss(labels, shifted_logits)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """
        Assemble the keyword arguments for one generation step.

        With a non-empty cache only the last token (and its position) is kept. Position ids are taken from
        `kwargs` when present, otherwise derived from `attention_mask` with an exclusive cumulative sum.
        `inputs_embeds` is accepted but not forwarded.
        """
        # Omit tokens covered by past_key_values
        if past_key_values:
            input_ids = tf.expand_dims(input_ids[:, -1], -1)

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
            if past_key_values:
                position_ids = tf.expand_dims(position_ids[:, -1], -1)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
        }

    def build(self, input_shape=None):
        """Build the main layer and the LM head (whose input width is `config.hidden_size`)."""
        if self.built:
            return
        self.built = True
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)
        if getattr(self, "lm_head", None) is not None:
            with tf.name_scope(self.lm_head.name):
                self.lm_head.build((self.config.hidden_size,))


@add_start_docstrings(
    """
    The Mistral Model transformer with a sequence classification head on top (linear layer).

    [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    MISTRAL_START_DOCSTRING,
)
class TFMistralForSequenceClassification(TFMistralPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        self.model = TFMistralMainLayer(config, name="model")
        self.score = keras.layers.Dense(
            self.num_labels,
            use_bias=False,
            kernel_initializer=get_initializer(config.initializer_range),
            name="score",
        )
        self.config = config

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @unpack_inputs
    @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def call(
        self,
        input_ids: tf.Tensor = None,
        attention_mask: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        past_key_values: Optional[List[tf.Tensor]] = None,
        inputs_embeds: Optional[tf.Tensor] = None,
        labels: Optional[tf.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TFSequenceClassifierOutputWithPast]:
        r"""
        Args:
            labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
                Labels for computing the sequence classification loss, one per sequence. They are flattened and
                compared against the pooled logits of shape `(batch_size, config.num_labels)`.
        """

        transformer_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)
        logits_shape = shape_list(logits)

        if self.config.pad_token_id is not None and input_ids is not None:
            # Index of the token just before the first pad token; -1 means "no pad found" (or a pad at
            # position 0) and is mapped to the last position of the row.
            sequence_lengths = (
                tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
                - 1
            )
            sequence_lengths = tf.where(
                sequence_lengths >= 0,
                sequence_lengths,
                tf.cast(shape_list(input_ids)[-1], sequence_lengths.dtype) - 1,
            )
            in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
        else:
            if self.config.pad_token_id is not None:
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )
            # Without padding information, pool the last position of every row. This is done whether or not
            # labels are given, so inference returns pooled logits too.
            in_logits = logits[:, -1]

        loss = None
        if labels is not None:
            if self.config.pad_token_id is None and logits_shape[0] != 1:
                raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

            loss = self.hf_compute_loss(tf.reshape(labels, [-1]), tf.reshape(in_logits, [-1, self.num_labels]))
        pooled_logits = in_logits

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def build(self, input_shape=None):
        """Build the main layer and the classification head (whose input width is `config.hidden_size`)."""
        if self.built:
            return
        self.built = True
        if getattr(self, "model", None) is not None:
            with tf.name_scope(self.model.name):
                self.model.build(None)
        if getattr(self, "score", None) is not None:
            with tf.name_scope(self.score.name):
                self.score.build((self.config.hidden_size,))


# ---- patch metadata (kept verbatim as comments) ----
# diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py
# index e0b396c716..337b0938b3 100644
# --- a/src/transformers/utils/dummy_tf_objects.py
# +++ b/src/transformers/utils/dummy_tf_objects.py
# @@ -1801,6 +1801,34 @@ class TFMBartPreTrainedModel(metaclass=DummyObject):
#          requires_backends(self, ["tf"])


class TFMistralForCausalLM(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFMistralForSequenceClassification(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFMistralModel(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFMistralPreTrainedModel(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


# ---- hunk context (unchanged lines of dummy_tf_objects.py) ----
# class TFMobileBertForMaskedLM(metaclass=DummyObject):
#     _backends = ["tf"]
#
# ---- patch metadata (kept verbatim as comments) ----
# diff --git a/tests/models/mistral/test_modeling_tf_mistral.py b/tests/models/mistral/test_modeling_tf_mistral.py
# new file mode 100644
# index 0000000000..df07e96bb1
# --- /dev/null
# +++ b/tests/models/mistral/test_modeling_tf_mistral.py
# @@ -0,0 +1,367 @@
# coding=utf-8
# Copyright 2024 Mistral AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the TF 2.0 Mistral model."""

import unittest

import numpy as np

from transformers import AutoTokenizer, MistralConfig, is_tf_available
from transformers.testing_utils import (
    require_tf,
    slow,
)

from ...generation.test_tf_utils import TFGenerationIntegrationTests
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_tf_available():
    import tensorflow as tf

    from transformers.models.mistral.modeling_tf_mistral import (
        TFMistralForCausalLM,
        TFMistralForSequenceClassification,
        TFMistralModel,
    )


class TFMistralModelTester:
    """Builds a tiny `MistralConfig` plus random inputs and hosts the shape checks used by the test case."""

    def __init__(self, parent):
        self.parent = parent
        self.batch_size = 13
        self.seq_length = 7
        self.is_training = True
        self.use_input_mask = True
        self.use_token_type_ids = False
        self.use_labels = True
        self.vocab_size = 99
        self.hidden_size = 32
        self.num_hidden_layers = 2
        self.num_attention_heads = 4
        self.num_key_value_heads = 2
        self.intermediate_size = 37
        self.hidden_act = "gelu"
        self.hidden_dropout_prob = 0.1
        self.attention_probs_dropout_prob = 0.1
        self.max_position_embeddings = 512
        self.type_vocab_size = 16
        self.type_sequence_label_size = 2
        self.initializer_range = 0.02
        self.num_labels = 3
        self.num_choices = 4
        self.scope = None
        # bos, eos and pad all share the last vocabulary id
        self.bos_token_id = self.vocab_size - 1
        self.eos_token_id = self.vocab_size - 1
        self.pad_token_id = self.vocab_size - 1

    def prepare_config_and_inputs(self):
        """Return `(config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels)`."""
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            # only the shape is passed: the helper's second parameter is not a vocabulary size
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = MistralConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            pad_token_id=self.pad_token_id,
        )

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = TFMistralModel(config=config)
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    # NOTE(review): nothing in this module calls the three helpers below, and they pass
    # `encoder_hidden_states` / `encoder_attention_mask`, which none of the `call` signatures shown
    # in this patch declare - confirm they are still wanted.
    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = TFMistralModel(config)
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        model = TFMistralForCausalLM(config=config)
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.is_decoder = True
        config.add_cross_attention = True
        model = TFMistralForCausalLM(config=config)

        # first forward pass
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and attention mask
        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
        next_attention_mask = tf.concat([input_mask, next_mask], axis=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        )["hidden_states"][0]

        # select random slice (TF tensors have no `.item()` / `.detach()`, so go through numpy)
        random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]).numpy()[0])
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
        output_from_past_slice = output_from_past[:, :, random_slice_idx]

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(np.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` with only `input_ids` and `attention_mask` for the common tests."""
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_tf
class TFMistralModelTest(TFModelTesterMixin, TFGenerationIntegrationTests, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
        (TFMistralModel, TFMistralForCausalLM, TFMistralForSequenceClassification) if is_tf_available() else ()
    )
    all_generative_model_classes = (TFMistralForCausalLM,) if is_tf_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": TFMistralModel,
            "text-classification": TFMistralForSequenceClassification,
            "text-generation": TFMistralForCausalLM,
            "zero-shot": TFMistralForSequenceClassification,
        }
        if is_tf_available()
        else {}
    )
    test_onnx = False
    test_pruning = False
    test_missing_keys = False
    test_head_masking = False

    # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
    def is_pipeline_test_to_skip(
        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
    ):
        # every pipeline test is skipped for now (see the TODO above)
        return True

    def setUp(self):
        self.model_tester = TFMistralModelTester(self)
        self.config_tester = ConfigTester(self, config_class=MistralConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for embedding_type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = embedding_type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_Mistral_sequence_classification_model(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        attention_mask = tf.not_equal(input_ids, 1)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = TFMistralForSequenceClassification(config)
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_Mistral_sequence_classification_model_for_single_label(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "single_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = tf.not_equal(input_ids, 1)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = TFMistralForSequenceClassification(config)
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_Mistral_sequence_classification_model_for_multi_label(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "multi_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = tf.not_equal(input_ids, 1)
        # NOTE(review): labels have the same `(batch_size,)` shape as in the single-label test - confirm
        # whether multi-hot labels were intended here.
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = TFMistralForSequenceClassification(config)
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    @unittest.skip("Mistral buffers include complex numbers, which breaks this test")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip("Mistral uses GQA on all models so the KV cache is a non standard format")
    def test_past_key_values_format(self):
        pass

    @unittest.skip("Vocab resizing is not supported")
    def test_save_load_after_resize_token_embeddings(self):
        pass


@require_tf
class TFMistralIntegrationTest(unittest.TestCase):
    # Both tests load the tiny random checkpoint, not a 7B model, despite their names.
    @slow
    def test_model_7b_logits(self):
        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
        model = TFMistralForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-MistralForCausalLM", from_pt=True
        )
        input_ids = tf.constant([input_ids])
        out = model(input_ids).logits
        # Expected mean on dim = -1
        EXPECTED_MEAN = tf.constant(
            [[-1.281e-04, -2.869e-04, -9.989e-05, -8.995e-05, 2.494e-04, -3.083e-04, -2.672e-04, -1.239e-04]]
        )
        tf.debugging.assert_near(tf.reduce_mean(out, axis=-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
        # slicing logits[0, 0, 0:30]
        EXPECTED_SLICE = tf.constant([0.1033, 0.1493, -0.0041, -0.0021, -0.1686, 0.0356, 0.0812, 0.2218, -0.1257, 0.1920, 0.0929, 0.1181, 0.0111, 0.0395, -0.0064, 0.1712, -0.0751, 0.0625, -0.2409, 0.1541, -0.1271, -0.2296, -0.0099, -0.0160, 0.0311, -0.0824, -0.1518, 0.0722, 0.0187, 0.0484])  # fmt: skip
        tf.debugging.assert_near(out[0, 0, :30], EXPECTED_SLICE, atol=1e-4, rtol=1e-4)

    @slow
    def test_model_7b_generation(self):
        EXPECTED_TEXT_COMPLETION = """My favourite condiment is Werk a EgyadjustPrintfigiousPDFPHPct guns Ein motor conceti barSequ内 infrastructure millretval"""
        prompt = "My favourite condiment is "
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM", use_fast=False)
        model = TFMistralForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-MistralForCausalLM", from_pt=True
        )
        input_ids = tokenizer.encode(prompt, return_tensors="tf")

        # greedy generation outputs
        generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0)
        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)