From 21aecc097162649c198ffbb37dbef17991314b0e Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Tue, 4 Jan 2022 13:23:10 +0100 Subject: [PATCH] Add Flax RoFormer (#15005) * Add FlaxRoFormer * Clean code + make quality * Fix output pooling for FlaxRoFormerForMultipleChoiceModule * Apply suggestions from code review * add flax model to repos Co-authored-by: Patrick von Platen --- docs/source/index.mdx | 2 +- docs/source/model_doc/roformer.mdx | 30 + src/transformers/__init__.py | 20 + .../models/auto/modeling_flax_auto.py | 7 + src/transformers/models/roformer/__init__.py | 27 +- .../models/roformer/modeling_flax_roformer.py | 1086 +++++++++++++++++ src/transformers/utils/dummy_flax_objects.py | 84 ++ tests/test_modeling_flax_roformer.py | 163 +++ 8 files changed, 1417 insertions(+), 2 deletions(-) create mode 100644 src/transformers/models/roformer/modeling_flax_roformer.py create mode 100644 tests/test_modeling_flax_roformer.py diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 8001ce56c4..b18487a081 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -246,7 +246,7 @@ Flax), PyTorch, and/or TensorFlow. | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | -| RoFormer | ✅ | ✅ | ✅ | ✅ | ❌ | +| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ | | SegFormer | ❌ | ❌ | ✅ | ❌ | ❌ | | SEW | ❌ | ❌ | ✅ | ❌ | ❌ | | SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/model_doc/roformer.mdx b/docs/source/model_doc/roformer.mdx index a00da9ef9b..435941d9f2 100644 --- a/docs/source/model_doc/roformer.mdx +++ b/docs/source/model_doc/roformer.mdx @@ -123,3 +123,33 @@ This model was contributed by [junnyu](https://huggingface.co/junnyu). The origi [[autodoc]] TFRoFormerForQuestionAnswering - call + +## FlaxRoFormerModel + +[[autodoc]] FlaxRoFormerModel + - __call__ + +## FlaxRoFormerForMaskedLM + +[[autodoc]] FlaxRoFormerForMaskedLM + - __call__ + +## FlaxRoFormerForSequenceClassification + +[[autodoc]] FlaxRoFormerForSequenceClassification + - __call__ + +## FlaxRoFormerForMultipleChoice + +[[autodoc]] FlaxRoFormerForMultipleChoice + - __call__ + +## FlaxRoFormerForTokenClassification + +[[autodoc]] FlaxRoFormerForTokenClassification + - __call__ + +## FlaxRoFormerForQuestionAnswering + +[[autodoc]] FlaxRoFormerForQuestionAnswering + - __call__ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 93ae3d1ffc..de36b81688 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2084,6 +2084,17 @@ if is_flax_available(): "FlaxRobertaPreTrainedModel", ] ) + _import_structure["models.roformer"].extend( + [ + "FlaxRoFormerForMaskedLM", + "FlaxRoFormerForMultipleChoice", + "FlaxRoFormerForQuestionAnswering", + "FlaxRoFormerForSequenceClassification", + "FlaxRoFormerForTokenClassification", + "FlaxRoFormerModel", + "FlaxRoFormerPreTrainedModel", + ] + ) _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"]) _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel") _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"]) @@ -3819,6 +3830,15 @@ if TYPE_CHECKING: FlaxRobertaModel, FlaxRobertaPreTrainedModel, ) + from .models.roformer import ( + FlaxRoFormerForMaskedLM, + FlaxRoFormerForMultipleChoice, + FlaxRoFormerForQuestionAnswering, + FlaxRoFormerForSequenceClassification, + FlaxRoFormerForTokenClassification, + FlaxRoFormerModel, + FlaxRoFormerPreTrainedModel, + ) from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 0191648f0e..dbe49bfc3f 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -50,6 +50,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict( ("wav2vec2", "FlaxWav2Vec2Model"), ("marian", "FlaxMarianModel"), ("blenderbot", "FlaxBlenderbotModel"), + ("roformer", "FlaxRoFormerModel"), ] ) @@ -66,6 +67,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( ("t5", "FlaxT5ForConditionalGeneration"), ("mt5", "FlaxMT5ForConditionalGeneration"), ("wav2vec2", "FlaxWav2Vec2ForPreTraining"), + ("roformer", "FlaxRoFormerForMaskedLM"), ] ) @@ -80,6 +82,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( ("bart", "FlaxBartForConditionalGeneration"), ("electra", "FlaxElectraForMaskedLM"), ("mbart", "FlaxMBartForConditionalGeneration"), + ("roformer", "FlaxRoFormerForMaskedLM"), ] ) @@ -132,6 +135,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("bart", "FlaxBartForSequenceClassification"), ("electra", "FlaxElectraForSequenceClassification"), ("mbart", "FlaxMBartForSequenceClassification"), + ("roformer", "FlaxRoFormerForSequenceClassification"), ] ) @@ -146,6 +150,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ("bart", "FlaxBartForQuestionAnswering"), ("electra", "FlaxElectraForQuestionAnswering"), ("mbart", "FlaxMBartForQuestionAnswering"), + ("roformer", "FlaxRoFormerForQuestionAnswering"), ] ) @@ -158,6 +163,7 @@ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("bert", "FlaxBertForTokenClassification"), ("big_bird", "FlaxBigBirdForTokenClassification"), ("electra", "FlaxElectraForTokenClassification"), + ("roformer", "FlaxRoFormerForTokenClassification"), ] ) @@ -170,6 +176,7 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( ("bert", "FlaxBertForMultipleChoice"), ("big_bird", "FlaxBigBirdForMultipleChoice"), ("electra", "FlaxElectraForMultipleChoice"), + ("roformer", "FlaxRoFormerForMultipleChoice"), ] ) diff --git a/src/transformers/models/roformer/__init__.py b/src/transformers/models/roformer/__init__.py index c405d5aab6..85853a3c74 100644 --- a/src/transformers/models/roformer/__init__.py +++ b/src/transformers/models/roformer/__init__.py @@ -17,7 +17,7 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available +from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available _import_structure = { @@ -59,6 +59,19 @@ if is_tf_available(): ] +if is_flax_available(): + _import_structure["modeling_flax_roformer"] = [ + "FLAX_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "FlaxRoFormerForMaskedLM", + "FlaxRoFormerForMultipleChoice", + "FlaxRoFormerForQuestionAnswering", + "FlaxRoFormerForSequenceClassification", + "FlaxRoFormerForTokenClassification", + "FlaxRoFormerModel", + "FlaxRoFormerPreTrainedModel", + ] + + if TYPE_CHECKING: from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig from .tokenization_roformer import RoFormerTokenizer @@ -95,6 +108,18 @@ if TYPE_CHECKING: TFRoFormerPreTrainedModel, ) + if is_flax_available(): + from .modeling_flax_roformer import ( + FLAX_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + FlaxRoFormerForMaskedLM, + FlaxRoFormerForMultipleChoice, + FlaxRoFormerForQuestionAnswering, + FlaxRoFormerForSequenceClassification, + FlaxRoFormerForTokenClassification, + FlaxRoFormerModel, + FlaxRoFormerPreTrainedModel, + ) + else: import sys diff --git a/src/transformers/models/roformer/modeling_flax_roformer.py b/src/transformers/models/roformer/modeling_flax_roformer.py new file mode 100644 index 0000000000..70c8b93009 --- /dev/null +++ b/src/transformers/models/roformer/modeling_flax_roformer.py @@ -0,0 +1,1086 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Flax RoFormer model.""" + +from typing import Callable, Optional, Tuple + +import numpy as np + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict +from flax.linen.attention import dot_product_attention_weights +from jax import lax + +from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxMaskedLMOutput, + FlaxMultipleChoiceModelOutput, + FlaxQuestionAnsweringModelOutput, + FlaxSequenceClassifierOutput, + FlaxTokenClassifierOutput, +) +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring +from ...utils import logging +from .configuration_roformer import RoFormerConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "junnyu/roformer_chinese_base" +_CONFIG_FOR_DOC = "RoFormerConfig" +_TOKENIZER_FOR_DOC = "RoFormerTokenizer" + +FLAX_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "junnyu/roformer_chinese_small", + "junnyu/roformer_chinese_base", + "junnyu/roformer_chinese_char_small", + "junnyu/roformer_chinese_char_base", + "junnyu/roformer_small_discriminator", + "junnyu/roformer_small_generator" + # See all RoFormer models at https://huggingface.co/models?filter=roformer +] + + +ROFORMER_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`RoFormerConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +ROFORMER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`numpy.ndarray` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`RoFormerTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + head_mask (`numpy.ndarray` of shape `({0})`, `optional): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.marian.modeling_flax_marian.create_sinusoidal_positions +def create_sinusoidal_positions(n_pos, dim): + position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) + sentinel = dim // 2 + dim % 2 + out = np.zeros_like(position_enc) + out[:, 0:sentinel] = np.sin(position_enc[:, 0::2]) + out[:, sentinel:] = np.cos(position_enc[:, 1::2]) + + return jnp.array(out) + + +class FlaxRoFormerEmbeddings(nn.Module): + """Construct the embeddings from word and token_type embeddings.""" + + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.token_type_embeddings = nn.Embed( + self.config.type_vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, input_ids, token_type_ids, attention_mask, deterministic: bool = True): + # Embed + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) + + # Sum all embeddings + hidden_states = inputs_embeds + token_type_embeddings + + # Layer Norm + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxRoFormerSelfAttention(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self) -> None: + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\ + : {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + self.rotary_value = self.config.rotary_value + + def __call__( + self, + hidden_states, + attention_mask, + sinusoidal_pos, + layer_head_mask, + deterministic=True, + output_attentions: bool = False, + ): + head_dim = self.config.hidden_size // self.config.num_attention_heads + + query_states = self.query(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + value_states = self.value(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + key_states = self.key(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + + if sinusoidal_pos is not None: + if self.rotary_value: + query_states, key_states, value_states = self.apply_rotary_position_embeddings( + sinusoidal_pos, query_states, key_states, value_states + ) + else: + query_states, key_states = self.apply_rotary_position_embeddings( + sinusoidal_pos, query_states, key_states + ) + + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, -1e10).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + @staticmethod + def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None): + sin, cos = sinusoidal_pos.split(2, axis=-1) + sin_pos = jnp.stack([sin, sin], axis=-1).reshape(sinusoidal_pos.shape) + cos_pos = jnp.stack([cos, cos], axis=-1).reshape(sinusoidal_pos.shape) + + def rotate_layer(layer, sin_pos, cos_pos): + rotate_half_layer = jnp.stack([-layer[..., 1::2], layer[..., ::2]], axis=-1).reshape(layer.shape) + rotary_matrix_cos = jnp.einsum("bslh,...sh->bslh", layer, cos_pos) + rotary_matrix_sin = jnp.einsum("bslh,...sh->bslh", rotate_half_layer, sin_pos) + return rotary_matrix_cos + rotary_matrix_sin + + query_layer = rotate_layer(query_layer, sin_pos, cos_pos) + key_layer = rotate_layer(key_layer, sin_pos, cos_pos) + if value_layer is not None: + value_layer = rotate_layer(value_layer, sin_pos, cos_pos) + return query_layer, key_layer, value_layer + return query_layer, key_layer + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->RoFormer +class FlaxRoFormerSelfOutput(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class FlaxRoFormerAttention(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.self = FlaxRoFormerSelfAttention(self.config, dtype=self.dtype) + self.output = FlaxRoFormerSelfOutput(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + sinusoidal_pos, + layer_head_mask, + deterministic=True, + output_attentions: bool = False, + ): + # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) + # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable + # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) + attn_outputs = self.self( + hidden_states, + attention_mask, + sinusoidal_pos, + layer_head_mask=layer_head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->RoFormer +class FlaxRoFormerIntermediate(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->RoFormer +class FlaxRoFormerOutput(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + attention_output) + return hidden_states + + +class FlaxRoFormerLayer(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxRoFormerAttention(self.config, dtype=self.dtype) + self.intermediate = FlaxRoFormerIntermediate(self.config, dtype=self.dtype) + self.output = FlaxRoFormerOutput(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + sinusiodal_pos, + layer_head_mask, + deterministic: bool = True, + output_attentions: bool = False, + ): + attention_outputs = self.attention( + hidden_states, + attention_mask, + sinusiodal_pos, + layer_head_mask=layer_head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = attention_outputs[0] + + hidden_states = self.intermediate(attention_output) + hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + return outputs + + +class FlaxRoFormerLayerCollection(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxRoFormerLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask, + sinusoidal_pos, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + # Check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.shape[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for \ + {head_mask.shape[0]}." + ) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + attention_mask, + sinusoidal_pos, + layer_head_mask=head_mask[i] if head_mask is not None else None, + deterministic=deterministic, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states,) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxRoFormerEncoder(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.embed_positions = create_sinusoidal_positions( + self.config.max_position_embeddings, self.config.hidden_size // self.config.num_attention_heads + ) + self.layer = FlaxRoFormerLayerCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + sinusoidal_pos = self.embed_positions[: hidden_states.shape[1], :] + + return self.layer( + hidden_states, + attention_mask, + sinusoidal_pos, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPredictionHeadTransform with Bert->RoFormer +class FlaxRoFormerPredictionHeadTransform(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) + self.activation = ACT2FN[self.config.hidden_act] + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return self.LayerNorm(hidden_states) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLMPredictionHead with Bert->RoFormer +class FlaxRoFormerLMPredictionHead(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.transform = FlaxRoFormerPredictionHeadTransform(self.config, dtype=self.dtype) + self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) + self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) + + def __call__(self, hidden_states, shared_embedding=None): + hidden_states = self.transform(hidden_states) + + if shared_embedding is not None: + hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + hidden_states = self.decoder(hidden_states) + + bias = jnp.asarray(self.bias, self.dtype) + hidden_states += bias + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOnlyMLMHead with Bert->RoFormer +class FlaxRoFormerOnlyMLMHead(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.predictions = FlaxRoFormerLMPredictionHead(self.config, dtype=self.dtype) + + def __call__(self, hidden_states, shared_embedding=None): + hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding) + return hidden_states + + +class FlaxRoFormerClassificationHead(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.out_proj = nn.Dense( + self.config.num_labels, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = hidden_states[:, 0, :] # take token (equiv. to [CLS]) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.dense(hidden_states) + hidden_states = nn.tanh(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class FlaxRoFormerPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RoFormerConfig + base_model_prefix = "roformer" + module_class: nn.Module = None + + def __init__( + self, + config: RoFormerConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + **kwargs + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + token_type_ids = jnp.zeros_like(input_ids) + attention_mask = jnp.ones_like(input_ids) + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.module.init(rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False)[ + "params" + ] + + @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + head_mask=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # init input tensors if not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + if head_mask is None: + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(token_type_ids, dtype="i4"), + jnp.array(head_mask, dtype="i4"), + not train, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + ) + + +class FlaxRoFormerModule(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.embeddings = FlaxRoFormerEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxRoFormerEncoder(self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + hidden_states = self.embeddings(input_ids, token_type_ids, attention_mask, deterministic=deterministic) + outputs = self.encoder( + hidden_states, + attention_mask, + head_mask=head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + + if not return_dict: + return (hidden_states,) + outputs[1:] + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "The bare RoFormer Model transformer outputting raw hidden-states without any specific head on top.", + ROFORMER_START_DOCSTRING, +) +class FlaxRoFormerModel(FlaxRoFormerPreTrainedModel): + module_class = FlaxRoFormerModule + + +append_call_sample_docstring( + FlaxRoFormerModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC +) + + +class FlaxRoFormerForMaskedLMModule(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype) + self.cls = FlaxRoFormerOnlyMLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roformer( + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.roformer.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.cls(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxMaskedLMOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING) +class FlaxRoFormerForMaskedLM(FlaxRoFormerPreTrainedModel): + module_class = FlaxRoFormerForMaskedLMModule + + +append_call_sample_docstring( + FlaxRoFormerForMaskedLM, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxMaskedLMOutput, + _CONFIG_FOR_DOC, + mask="", +) + + +class FlaxRoFormerForSequenceClassificationModule(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype) + self.classifier = FlaxRoFormerClassificationHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roformer( + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output, deterministic=deterministic) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RoFormer Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + ROFORMER_START_DOCSTRING, +) +class FlaxRoFormerForSequenceClassification(FlaxRoFormerPreTrainedModel): + module_class = FlaxRoFormerForSequenceClassificationModule + + +append_call_sample_docstring( + FlaxRoFormerForSequenceClassification, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxRoFormerForMultipleChoiceModule(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense(1, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + num_choices = input_ids.shape[1] + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) + token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) + + # Model + outputs = self.roformer( + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Equivalent to sequence_summary call in the PyTorch implementation + hidden_states = outputs[0] + pooled_output = hidden_states[:, -1] + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + + logits = self.classifier(pooled_output) + + reshaped_logits = logits.reshape(-1, num_choices) + + if not return_dict: + return (reshaped_logits,) + outputs[2:] + + return FlaxMultipleChoiceModelOutput( + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RoFormer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + ROFORMER_START_DOCSTRING, +) +class FlaxRoFormerForMultipleChoice(FlaxRoFormerPreTrainedModel): + module_class = FlaxRoFormerForMultipleChoiceModule + + +overwrite_call_docstring( + FlaxRoFormerForMultipleChoice, ROFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") +) +append_call_sample_docstring( + FlaxRoFormerForMultipleChoice, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxMultipleChoiceModelOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxRoFormerForTokenClassificationModule(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roformer( + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.classifier(hidden_states) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxTokenClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RoFormer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + ROFORMER_START_DOCSTRING, +) +class FlaxRoFormerForTokenClassification(FlaxRoFormerPreTrainedModel): + module_class = FlaxRoFormerForTokenClassificationModule + + +append_call_sample_docstring( + FlaxRoFormerForTokenClassification, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxTokenClassifierOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxRoFormerForQuestionAnsweringModule(nn.Module): + config: RoFormerConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype) + self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roformer( + input_ids, + attention_mask, + token_type_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + return (start_logits, end_logits) + outputs[1:] + + return FlaxQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + RoFormer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + ROFORMER_START_DOCSTRING, +) +class FlaxRoFormerForQuestionAnswering(FlaxRoFormerPreTrainedModel): + module_class = FlaxRoFormerForQuestionAnsweringModule + + +append_call_sample_docstring( + FlaxRoFormerForQuestionAnswering, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index c3a817dc7c..52c0e5e124 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -1316,6 +1316,90 @@ class FlaxRobertaPreTrainedModel: requires_backends(self, ["flax"]) +class FlaxRoFormerForMaskedLM: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + def __call__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxRoFormerForMultipleChoice: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + def __call__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxRoFormerForQuestionAnswering: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + def __call__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxRoFormerForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + def __call__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxRoFormerForTokenClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + def __call__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxRoFormerModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + def __call__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxRoFormerPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + def __call__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxT5ForConditionalGeneration: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) diff --git a/tests/test_modeling_flax_roformer.py b/tests/test_modeling_flax_roformer.py new file mode 100644 index 0000000000..cd1af7caf6 --- /dev/null +++ b/tests/test_modeling_flax_roformer.py @@ -0,0 +1,163 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from transformers import RoFormerConfig, is_flax_available +from transformers.testing_utils import require_flax, slow + +from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask + + +if is_flax_available(): + import jax.numpy as jnp + from transformers.models.roformer.modeling_flax_roformer import ( + FlaxRoFormerForMaskedLM, + FlaxRoFormerForMultipleChoice, + FlaxRoFormerForQuestionAnswering, + FlaxRoFormerForSequenceClassification, + FlaxRoFormerForTokenClassification, + FlaxRoFormerModel, + ) + + +class FlaxRoFormerModelTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_attention_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_choices=4, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_choices = num_choices + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + config = RoFormerConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + ) + + return config, input_ids, token_type_ids, attention_mask + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, token_type_ids, attention_mask = config_and_inputs + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask} + return config, inputs_dict + + +@require_flax +class FlaxRoFormerModelTest(FlaxModelTesterMixin, unittest.TestCase): + + test_head_masking = True + + all_model_classes = ( + ( + FlaxRoFormerModel, + FlaxRoFormerForMaskedLM, + FlaxRoFormerForSequenceClassification, + FlaxRoFormerForTokenClassification, + FlaxRoFormerForMultipleChoice, + FlaxRoFormerForQuestionAnswering, + ) + if is_flax_available() + else () + ) + + def setUp(self): + self.model_tester = FlaxRoFormerModelTester(self) + + @slow + def test_model_from_pretrained(self): + for model_class_name in self.all_model_classes: + model = model_class_name.from_pretrained("junnyu/roformer_chinese_small", from_pt=True) + outputs = model(np.ones((1, 1))) + self.assertIsNotNone(outputs) + + +@require_flax +class FlaxRoFormerModelIntegrationTest(unittest.TestCase): + @slow + def test_inference_masked_lm(self): + model = FlaxRoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base") + input_ids = jnp.array([[0, 1, 2, 3, 4, 5]]) + output = model(input_ids)[0] + + vocab_size = 50000 + + expected_shape = (1, 6, vocab_size) + self.assertEqual(output.shape, expected_shape) + + expected_slice = jnp.array( + [[[-0.1205, -1.0265, 0.2922], [-1.5134, 0.1974, 0.1519], [-5.0135, -3.9003, -0.8404]]] + ) + + self.assertTrue(jnp.allclose(output[:, :3, :3], expected_slice, atol=1e-4))