diff --git a/docs/source/index.rst b/docs/source/index.rst index 7eb2ae606b..9c837a5e73 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -321,7 +321,7 @@ Flax), PyTorch, and/or TensorFlow. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | Model | Tokenizer slow | Tokenizer fast | PyTorch support | TensorFlow support | Flax Support | +=============================+================+================+=================+====================+==============+ -| ALBERT | ✅ | ✅ | ✅ | ✅ | ❌ | +| ALBERT | ✅ | ✅ | ✅ | ✅ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | BART | ✅ | ✅ | ✅ | ✅ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ diff --git a/docs/source/model_doc/albert.rst b/docs/source/model_doc/albert.rst index c4b4eac02d..66fab82c10 100644 --- a/docs/source/model_doc/albert.rst +++ b/docs/source/model_doc/albert.rst @@ -43,7 +43,8 @@ Tips: similar to a BERT-like architecture with the same number of hidden layers as it has to iterate through the same number of (repeating) layers. -This model was contributed by `lysandre `__. The original code can be found `here +This model was contributed by `lysandre `__. This model jax version was contributed by +`kamalkraj `__. The original code can be found `here `__. AlbertConfig @@ -174,3 +175,52 @@ TFAlbertForQuestionAnswering .. autoclass:: transformers.TFAlbertForQuestionAnswering :members: call + + +FlaxAlbertModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAlbertModel + :members: __call__ + + +FlaxAlbertForPreTraining +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAlbertForPreTraining + :members: __call__ + + +FlaxAlbertForMaskedLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAlbertForMaskedLM + :members: __call__ + + +FlaxAlbertForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAlbertForSequenceClassification + :members: __call__ + + +FlaxAlbertForMultipleChoice +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAlbertForMultipleChoice + :members: __call__ + + +FlaxAlbertForTokenClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAlbertForTokenClassification + :members: __call__ + + +FlaxAlbertForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxAlbertForQuestionAnswering + :members: __call__ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 0a8ed48f86..57e44bfc73 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1642,6 +1642,18 @@ if is_flax_available(): "FlaxTopPLogitsWarper", ] _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"] + _import_structure["models.albert"].extend( + [ + "FlaxAlbertForMaskedLM", + "FlaxAlbertForMultipleChoice", + "FlaxAlbertForPreTraining", + "FlaxAlbertForQuestionAnswering", + "FlaxAlbertForSequenceClassification", + "FlaxAlbertForTokenClassification", + "FlaxAlbertModel", + "FlaxAlbertPreTrainedModel", + ] + ) _import_structure["models.auto"].extend( [ "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING", @@ -3152,6 +3164,16 @@ if TYPE_CHECKING: FlaxTopPLogitsWarper, ) from .modeling_flax_utils import FlaxPreTrainedModel + from .models.albert import ( + FlaxAlbertForMaskedLM, + FlaxAlbertForMultipleChoice, + FlaxAlbertForPreTraining, + FlaxAlbertForQuestionAnswering, + FlaxAlbertForSequenceClassification, + FlaxAlbertForTokenClassification, + FlaxAlbertModel, + FlaxAlbertPreTrainedModel, + ) from .models.auto import ( FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, diff --git a/src/transformers/models/albert/__init__.py b/src/transformers/models/albert/__init__.py index 44e43eb30b..183140c6e0 100644 --- a/src/transformers/models/albert/__init__.py +++ b/src/transformers/models/albert/__init__.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING from ...file_utils import ( _LazyModule, + is_flax_available, is_sentencepiece_available, is_tf_available, is_tokenizers_available, @@ -65,6 +66,17 @@ if is_tf_available(): "TFAlbertPreTrainedModel", ] +if is_flax_available(): + _import_structure["modeling_flax_albert"] = [ + "FlaxAlbertForMaskedLM", + "FlaxAlbertForMultipleChoice", + "FlaxAlbertForPreTraining", + "FlaxAlbertForQuestionAnswering", + "FlaxAlbertForSequenceClassification", + "FlaxAlbertForTokenClassification", + "FlaxAlbertModel", + "FlaxAlbertPreTrainedModel", + ] if TYPE_CHECKING: from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig, AlbertOnnxConfig @@ -103,6 +115,17 @@ if TYPE_CHECKING: TFAlbertPreTrainedModel, ) + if is_flax_available(): + from .modeling_flax_albert import ( + FlaxAlbertForMaskedLM, + FlaxAlbertForMultipleChoice, + FlaxAlbertForPreTraining, + FlaxAlbertForQuestionAnswering, + FlaxAlbertForSequenceClassification, + FlaxAlbertForTokenClassification, + FlaxAlbertModel, + FlaxAlbertPreTrainedModel, + ) else: import sys diff --git a/src/transformers/models/albert/modeling_flax_albert.py b/src/transformers/models/albert/modeling_flax_albert.py new file mode 100644 index 0000000000..5fcec8d5bf --- /dev/null +++ b/src/transformers/models/albert/modeling_flax_albert.py @@ -0,0 +1,1109 @@ +# coding=utf-8 +# Copyright 2021 Google AI, Google Brain and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Tuple + +import numpy as np + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict +from flax.linen.attention import dot_product_attention_weights +from jax import lax + +from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPooling, + FlaxMaskedLMOutput, + FlaxMultipleChoiceModelOutput, + FlaxQuestionAnsweringModelOutput, + FlaxSequenceClassifierOutput, + FlaxTokenClassifierOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import logging +from .configuration_albert import AlbertConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "albert-base-v2" +_CONFIG_FOR_DOC = "AlbertConfig" +_TOKENIZER_FOR_DOC = "AlbertTokenizer" + + +@flax.struct.dataclass +class FlaxAlbertForPreTrainingOutput(ModelOutput): + """ + Output type of :class:`~transformers.FlaxAlbertForPreTraining`. + + Args: + prediction_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + sop_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + prediction_logits: jnp.ndarray = None + sop_logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +ALBERT_START_DOCSTRING = r""" + + This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the + generic methods the library implements for all its model (such as downloading, saving and converting weights from + PyTorch models) + + This model is also a Flax Linen `flax.linen.Module + `__ subclass. Use it as a regular Flax linen Module + and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - `Just-In-Time (JIT) compilation `__ + - `Automatic Differentiation `__ + - `Vectorization `__ + - `Parallelization `__ + + Parameters: + config (:class:`~transformers.AlbertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the + model weights. +""" + +ALBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.AlbertTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`__ + position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + +""" + + +class FlaxAlbertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.config.embedding_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.position_embeddings = nn.Embed( + self.config.max_position_embeddings, + self.config.embedding_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.token_type_embeddings = nn.Embed( + self.config.type_vocab_size, + self.config.embedding_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings.__call__ + def __call__(self, input_ids, token_type_ids, position_ids, deterministic: bool = True): + # Embed + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + position_embeds = self.position_embeddings(position_ids.astype("i4")) + token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) + + # Sum all embeddings + hidden_states = inputs_embeds + token_type_embeddings + position_embeds + + # Layer Norm + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxAlbertSelfAttention(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\ + : {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + ) + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): + head_dim = self.config.hidden_size // self.config.num_attention_heads + + query_states = self.query(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + value_states = self.value(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + key_states = self.key(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, -1e10).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + projected_attn_output = self.dense(attn_output) + projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic) + layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states) + outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,) + return outputs + + +class FlaxAlbertLayer(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype) + self.ffn = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + self.ffn_output = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + ): + attention_outputs = self.attention( + hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions + ) + attention_output = attention_outputs[0] + ffn_output = self.ffn(attention_output) + ffn_output = self.activation(ffn_output) + ffn_output = self.ffn_output(ffn_output) + ffn_output = self.dropout(ffn_output, deterministic=deterministic) + hidden_states = self.full_layer_layer_norm(ffn_output + attention_output) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + return outputs + + +class FlaxAlbertLayers(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + layer_index: Optional[str] = None + + def setup(self): + self.albert_layers = FlaxAlbertLayer(self.config, name=self.layer_index, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + ): + outputs = self.albert_layers( + hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions + ) + return outputs + + +class FlaxAlbertLayerCollection(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxAlbertLayers(self.config, name="albert_layers", layer_index=str(i), dtype=self.dtype) + for i in range(self.config.inner_group_num) + ] + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + ): + layer_hidden_states = () + layer_attentions = () + + for layer_index, albert_layer in enumerate(self.layers): + layer_output = albert_layer( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + ) + hidden_states = layer_output[0] + + if output_attentions: + layer_attentions = layer_attentions + (layer_output[1],) + + if output_hidden_states: + layer_hidden_states = layer_hidden_states + (hidden_states,) + + outputs = (hidden_states,) + if output_hidden_states: + outputs = outputs + (layer_hidden_states,) + if output_attentions: + outputs = outputs + (layer_attentions,) + return outputs # last-layer hidden state, (layer hidden states), (layer attentions) + + +class FlaxAlbertLayerCollections(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + layer_index: Optional[str] = None + + def setup(self): + self.albert_layer_groups = FlaxAlbertLayerCollection(self.config, name=self.layer_index, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + ): + outputs = self.albert_layer_groups( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + return outputs + + +class FlaxAlbertEncoder(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups) + self.embedding_hidden_mapping_in = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + self.layers = [ + FlaxAlbertLayerCollections(self.config, name="albert_layer_groups", layer_index=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_groups) + ] + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + hidden_states = self.embedding_hidden_mapping_in(hidden_states) + all_attentions = () if output_attentions else None + all_hidden_states = (hidden_states,) if output_hidden_states else None + + for i in range(self.config.num_hidden_layers): + # Index of the hidden group + group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups)) + layer_group_output = self.layers[group_idx]( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = layer_group_output[0] + + if output_attentions: + all_attentions = all_attentions + layer_group_output[-1] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxAlbertOnlyMLMHead(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype) + self.activation = ACT2FN[self.config.hidden_act] + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) + self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) + + def __call__(self, hidden_states, shared_embedding=None): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + + if shared_embedding is not None: + hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + hidden_states = self.decoder(hidden_states) + + hidden_states += self.bias + return hidden_states + + +class FlaxAlbertSOPHead(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dropout = nn.Dropout(self.config.classifier_dropout_prob) + self.classifier = nn.Dense(2, dtype=self.dtype) + + def __call__(self, pooled_output, deterministic=True): + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + return logits + + +class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = AlbertConfig + base_model_prefix = "albert" + module_class: nn.Module = None + + def __init__( + self, + config: AlbertConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + **kwargs + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + token_type_ids = jnp.zeros_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) + attention_mask = jnp.ones_like(input_ids) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[ + "params" + ] + + @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # init input tensors if not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(token_type_ids, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + not train, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + ) + + +class FlaxAlbertModule(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True + + def setup(self): + self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxAlbertEncoder(self.config, dtype=self.dtype) + if self.add_pooling_layer: + self.pooler = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + name="pooler", + ) + self.pooler_activation = nn.tanh + else: + self.pooler = None + self.pooler_activation = None + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids: Optional[np.ndarray] = None, + position_ids: Optional[np.ndarray] = None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # make sure `token_type_ids` is correctly initialized when not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + # make sure `position_ids` is correctly initialized when not passed + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic) + + outputs = self.encoder( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + if self.add_pooling_layer: + pooled = self.pooler(hidden_states[:, 0]) + pooled = self.pooler_activation(pooled) + else: + pooled = None + + if not return_dict: + # if pooled is None, don't return it + if pooled is None: + return (hidden_states,) + outputs[1:] + return (hidden_states, pooled) + outputs[1:] + + return FlaxBaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=pooled, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "The bare Albert Model transformer outputting raw hidden-states without any specific head on top.", + ALBERT_START_DOCSTRING, +) +class FlaxAlbertModel(FlaxAlbertPreTrainedModel): + module_class = FlaxAlbertModule + + +append_call_sample_docstring( + FlaxAlbertModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC +) + + +class FlaxAlbertForPreTrainingModule(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype) + self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype) + self.sop_classifier = FlaxAlbertSOPHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + + # Model + outputs = self.albert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.tie_word_embeddings: + shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + hidden_states = outputs[0] + pooled_output = outputs[1] + + prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding) + sop_scores = self.sop_classifier(pooled_output, deterministic=deterministic) + + if not return_dict: + return (prediction_scores, sop_scores) + outputs[2:] + + return FlaxAlbertForPreTrainingOutput( + prediction_logits=prediction_scores, + sop_logits=sop_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a + `sentence order prediction (classification)` head. + """, + ALBERT_START_DOCSTRING, +) +class FlaxAlbertForPreTraining(FlaxAlbertPreTrainedModel): + module_class = FlaxAlbertForPreTrainingModule + + +FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING = """ + Returns: + + Example:: + + >>> from transformers import AlbertTokenizer, FlaxAlbertForPreTraining + + >>> tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2') + >>> model = FlaxAlbertForPreTraining.from_pretrained('albert-base-v2') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.sop_logits +""" + +overwrite_call_docstring( + FlaxAlbertForPreTraining, + ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING, +) +append_replace_return_docstrings( + FlaxAlbertForPreTraining, output_type=FlaxAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC +) + + +class FlaxAlbertForMaskedLMModule(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.albert = FlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.albert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.predictions(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxMaskedLMOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING) +class FlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel): + module_class = FlaxAlbertForMaskedLMModule + + +append_call_sample_docstring( + FlaxAlbertForMaskedLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC +) + + +class FlaxAlbertForSequenceClassificationModule(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype) + classifier_dropout = ( + self.config.classifier_dropout_prob + if self.config.classifier_dropout_prob is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) + self.classifier = nn.Dense( + self.config.num_labels, + dtype=self.dtype, + ) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.albert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + if not return_dict: + return (logits,) + outputs[2:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + ALBERT_START_DOCSTRING, +) +class FlaxAlbertForSequenceClassification(FlaxAlbertPreTrainedModel): + module_class = FlaxAlbertForSequenceClassificationModule + + +append_call_sample_docstring( + FlaxAlbertForSequenceClassification, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxAlbertForMultipleChoiceModule(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense(1, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + num_choices = input_ids.shape[1] + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None + token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None + position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None + + # Model + outputs = self.albert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.reshape(-1, num_choices) + + if not return_dict: + return (reshaped_logits,) + outputs[2:] + + return FlaxMultipleChoiceModelOutput( + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + ALBERT_START_DOCSTRING, +) +class FlaxAlbertForMultipleChoice(FlaxAlbertPreTrainedModel): + module_class = FlaxAlbertForMultipleChoiceModule + + +overwrite_call_docstring( + FlaxAlbertForMultipleChoice, ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") +) +append_call_sample_docstring( + FlaxAlbertForMultipleChoice, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxMultipleChoiceModelOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxAlbertForTokenClassificationModule(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + classifier_dropout = ( + self.config.classifier_dropout_prob + if self.config.classifier_dropout_prob is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.albert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.classifier(hidden_states) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxTokenClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + ALBERT_START_DOCSTRING, +) +class FlaxAlbertForTokenClassification(FlaxAlbertPreTrainedModel): + module_class = FlaxAlbertForTokenClassificationModule + + +append_call_sample_docstring( + FlaxAlbertForTokenClassification, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxTokenClassifierOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxAlbertForQuestionAnsweringModule(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.albert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + return (start_logits, end_logits) + outputs[1:] + + return FlaxQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + ALBERT_START_DOCSTRING, +) +class FlaxAlbertForQuestionAnswering(FlaxAlbertPreTrainedModel): + module_class = FlaxAlbertForQuestionAnsweringModule + + +append_call_sample_docstring( + FlaxAlbertForQuestionAnswering, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index f6f260b496..a0bfb609d3 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -29,6 +29,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping ("distilbert", "FlaxDistilBertModel"), + ("albert", "FlaxAlbertModel"), ("roberta", "FlaxRobertaModel"), ("bert", "FlaxBertModel"), ("big_bird", "FlaxBigBirdModel"), @@ -49,6 +50,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict( FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( [ # Model for pre-training mapping + ("albert", "FlaxAlbertForPreTraining"), ("roberta", "FlaxRobertaForMaskedLM"), ("bert", "FlaxBertForPreTraining"), ("big_bird", "FlaxBigBirdForPreTraining"), @@ -65,6 +67,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ # Model for Masked LM mapping ("distilbert", "FlaxDistilBertForMaskedLM"), + ("albert", "FlaxAlbertForMaskedLM"), ("roberta", "FlaxRobertaForMaskedLM"), ("bert", "FlaxBertForMaskedLM"), ("big_bird", "FlaxBigBirdForMaskedLM"), @@ -104,6 +107,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping ("distilbert", "FlaxDistilBertForSequenceClassification"), + ("albert", "FlaxAlbertForSequenceClassification"), ("roberta", "FlaxRobertaForSequenceClassification"), ("bert", "FlaxBertForSequenceClassification"), ("big_bird", "FlaxBigBirdForSequenceClassification"), @@ -117,6 +121,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Question Answering mapping ("distilbert", "FlaxDistilBertForQuestionAnswering"), + ("albert", "FlaxAlbertForQuestionAnswering"), ("roberta", "FlaxRobertaForQuestionAnswering"), ("bert", "FlaxBertForQuestionAnswering"), ("big_bird", "FlaxBigBirdForQuestionAnswering"), @@ -130,6 +135,7 @@ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Token Classification mapping ("distilbert", "FlaxDistilBertForTokenClassification"), + ("albert", "FlaxAlbertForTokenClassification"), ("roberta", "FlaxRobertaForTokenClassification"), ("bert", "FlaxBertForTokenClassification"), ("big_bird", "FlaxBigBirdForTokenClassification"), @@ -141,6 +147,7 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( [ # Model for Multiple Choice mapping ("distilbert", "FlaxDistilBertForMultipleChoice"), + ("albert", "FlaxAlbertForMultipleChoice"), ("roberta", "FlaxRobertaForMultipleChoice"), ("bert", "FlaxBertForMultipleChoice"), ("big_bird", "FlaxBigBirdForMultipleChoice"), diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index c3ee8b7a27..da9eef8749 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -76,6 +76,74 @@ class FlaxPreTrainedModel: requires_backends(cls, ["flax"]) +class FlaxAlbertForMaskedLM: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxAlbertForMultipleChoice: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxAlbertForPreTraining: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxAlbertForQuestionAnswering: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxAlbertForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxAlbertForTokenClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxAlbertModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + +class FlaxAlbertPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = None diff --git a/tests/test_modeling_flax_albert.py b/tests/test_modeling_flax_albert.py new file mode 100644 index 0000000000..ca8cfeafdd --- /dev/null +++ b/tests/test_modeling_flax_albert.py @@ -0,0 +1,161 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from transformers import AlbertConfig, is_flax_available +from transformers.testing_utils import require_flax, slow + +from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask + + +if is_flax_available(): + import jax.numpy as jnp + from transformers.models.albert.modeling_flax_albert import ( + FlaxAlbertForMaskedLM, + FlaxAlbertForMultipleChoice, + FlaxAlbertForPreTraining, + FlaxAlbertForQuestionAnswering, + FlaxAlbertForSequenceClassification, + FlaxAlbertForTokenClassification, + FlaxAlbertModel, + ) + + +class FlaxAlbertModelTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_attention_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_choices=4, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_choices = num_choices + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + config = AlbertConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + ) + + return config, input_ids, token_type_ids, attention_mask + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, token_type_ids, attention_mask = config_and_inputs + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask} + return config, inputs_dict + + +@require_flax +class FlaxAlbertModelTest(FlaxModelTesterMixin, unittest.TestCase): + + all_model_classes = ( + ( + FlaxAlbertModel, + FlaxAlbertForPreTraining, + FlaxAlbertForMaskedLM, + FlaxAlbertForMultipleChoice, + FlaxAlbertForQuestionAnswering, + FlaxAlbertForSequenceClassification, + FlaxAlbertForTokenClassification, + FlaxAlbertForQuestionAnswering, + ) + if is_flax_available() + else () + ) + + def setUp(self): + self.model_tester = FlaxAlbertModelTester(self) + + @slow + def test_model_from_pretrained(self): + for model_class_name in self.all_model_classes: + model = model_class_name.from_pretrained("albert-base-v2") + outputs = model(np.ones((1, 1))) + self.assertIsNotNone(outputs) + + +@require_flax +class FlaxAlbertModelIntegrationTest(unittest.TestCase): + @slow + def test_inference_no_head_absolute_embedding(self): + model = FlaxAlbertModel.from_pretrained("albert-base-v2") + input_ids = np.array([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]]) + attention_mask = np.array([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) + output = model(input_ids, attention_mask=attention_mask)[0] + expected_shape = (1, 11, 768) + self.assertEqual(output.shape, expected_shape) + expected_slice = np.array( + [[[-0.6513, 1.5035, -0.2766], [-0.6515, 1.5046, -0.2780], [-0.6512, 1.5049, -0.2784]]] + ) + + self.assertTrue(jnp.allclose(output[:, 1:4, 1:4], expected_slice, atol=1e-4))