From f748bd424213ca8e76e6ad9ffe2beece2ff2655e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 29 Apr 2021 12:04:51 +0200 Subject: [PATCH] [Flax] Add docstrings & model outputs (#11498) * add attentions & hidden states * add model outputs + docs * finish docs * finish tests * finish impl * del @ * finish * finish * correct test * apply sylvains suggestions * Update src/transformers/models/bert/modeling_flax_bert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * simplify more Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/file_utils.py | 162 +++++- src/transformers/modeling_flax_outputs.py | 239 ++++++++ src/transformers/modeling_flax_utils.py | 21 + .../models/bert/modeling_flax_bert.py | 522 ++++++++++++++++-- .../models/roberta/modeling_flax_roberta.py | 160 +++++- tests/test_modeling_common.py | 1 - tests/test_modeling_flax_common.py | 115 +++- 7 files changed, 1130 insertions(+), 90 deletions(-) create mode 100644 src/transformers/modeling_flax_outputs.py diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index ca0cddc9d5..93c032b722 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -794,6 +794,17 @@ PT_CAUSAL_LM_SAMPLE = r""" >>> logits = outputs.logits """ +PT_SAMPLE_DOCSTRINGS = { + "SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE, + "QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE, + "TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE, + "MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE, + "MaskedLM": PT_MASKED_LM_SAMPLE, + "LMHead": PT_CAUSAL_LM_SAMPLE, + "BaseModel": PT_BASE_MODEL_SAMPLE, +} + + TF_TOKEN_CLASSIFICATION_SAMPLE = r""" Example:: @@ -915,30 +926,148 @@ TF_CAUSAL_LM_SAMPLE = r""" >>> logits = outputs.logits """ +TF_SAMPLE_DOCSTRINGS = { + "SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE, + "QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE, + "TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE, + "MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE, + "MaskedLM": TF_MASKED_LM_SAMPLE, + "LMHead": TF_CAUSAL_LM_SAMPLE, + "BaseModel": TF_BASE_MODEL_SAMPLE, +} + + +FLAX_TOKEN_CLASSIFICATION_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax') + + >>> outputs = model(**inputs) + >>> logits = outputs.logits +""" + +FLAX_QUESTION_ANSWERING_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + >>> inputs = tokenizer(question, text, return_tensors='jax') + + >>> outputs = model(**inputs) + >>> start_scores = outputs.start_logits + >>> end_scores = outputs.end_logits +""" + +FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax') + + >>> outputs = model(**inputs, labels=labels) + >>> logits = outputs.logits +""" + +FLAX_MASKED_LM_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors='jax') + + >>> outputs = model(**inputs) + >>> logits = outputs.logits +""" + +FLAX_BASE_MODEL_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax') + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state +""" + +FLAX_MULTIPLE_CHOICE_SAMPLE = r""" + Example:: + + >>> from transformers import {tokenizer_class}, {model_class} + + >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> choice0 = "It is eaten with a fork and a knife." + >>> choice1 = "It is eaten while held in the hand." + + >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='jax', padding=True) + >>> outputs = model(**{{k: v[None, :] for k,v in encoding.items()}}) + + >>> logits = outputs.logits +""" + +FLAX_SAMPLE_DOCSTRINGS = { + "SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE, + "QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE, + "TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE, + "MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE, + "MaskedLM": FLAX_MASKED_LM_SAMPLE, + "BaseModel": FLAX_BASE_MODEL_SAMPLE, +} + def add_code_sample_docstrings( - *docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None + *docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None, model_cls=None ): def docstring_decorator(fn): - model_class = fn.__qualname__.split(".")[0] - is_tf_class = model_class[:2] == "TF" + # model_class defaults to function's class if not specified otherwise + model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls + + if model_class[:2] == "TF": + sample_docstrings = TF_SAMPLE_DOCSTRINGS + elif model_class[:4] == "Flax": + sample_docstrings = FLAX_SAMPLE_DOCSTRINGS + else: + sample_docstrings = PT_SAMPLE_DOCSTRINGS + doc_kwargs = dict(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint) if "SequenceClassification" in model_class: - code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE + code_sample = sample_docstrings["SequenceClassification"] elif "QuestionAnswering" in model_class: - code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE + code_sample = sample_docstrings["QuestionAnswering"] elif "TokenClassification" in model_class: - code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE + code_sample = sample_docstrings["TokenClassification"] elif "MultipleChoice" in model_class: - code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE + code_sample = sample_docstrings["MultipleChoice"] elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]: doc_kwargs["mask"] = "[MASK]" if mask is None else mask - code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE + code_sample = sample_docstrings["MaskedLM"] elif "LMHead" in model_class or "CausalLM" in model_class: - code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE + code_sample = sample_docstrings["LMHead"] elif "Model" in model_class or "Encoder" in model_class: - code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE + code_sample = sample_docstrings["BaseModel"] else: raise ValueError(f"Docstring can't be built for model {model_class}") @@ -1462,7 +1591,10 @@ def tf_required(func): def is_tensor(x): - """Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`.""" + """ + Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor`, obj:`jaxlib.xla_extension.DeviceArray` or + :obj:`np.ndarray`. + """ if is_torch_available(): import torch @@ -1473,6 +1605,14 @@ def is_tensor(x): if isinstance(x, tf.Tensor): return True + + if is_flax_available(): + import jaxlib.xla_extension as jax_xla + from jax.interpreters.partial_eval import DynamicJaxprTracer + + if isinstance(x, (jax_xla.DeviceArray, DynamicJaxprTracer)): + return True + return isinstance(x, np.ndarray) diff --git a/src/transformers/modeling_flax_outputs.py b/src/transformers/modeling_flax_outputs.py new file mode 100644 index 0000000000..5f96307ed3 --- /dev/null +++ b/src/transformers/modeling_flax_outputs.py @@ -0,0 +1,239 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple + +import jaxlib.xla_extension as jax_xla + +from .file_utils import ModelOutput + + +@dataclass +class FlaxBaseModelOutput(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: jax_xla.DeviceArray = None + hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None + attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + + +@dataclass +class FlaxBaseModelOutputWithPooling(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) further processed by a + Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence + prediction (classification) objective during pretraining. + hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: jax_xla.DeviceArray = None + pooler_output: jax_xla.DeviceArray = None + hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None + attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + + +@dataclass +class FlaxMaskedLMOutput(ModelOutput): + """ + Base class for masked language models outputs. + + Args: + logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: jax_xla.DeviceArray = None + hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None + attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + + +@dataclass +class FlaxNextSentencePredictorOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: jax_xla.DeviceArray = None + hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None + attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + + +@dataclass +class FlaxSequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: jax_xla.DeviceArray = None + hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None + attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + + +@dataclass +class FlaxMultipleChoiceModelOutput(ModelOutput): + """ + Base class for outputs of multiple choice models. + + Args: + logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, num_choices)`): + `num_choices` is the second dimension of the input tensors. (see `input_ids` above). + + Classification scores (before SoftMax). + hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: jax_xla.DeviceArray = None + hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None + attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + + +@dataclass +class FlaxTokenClassifierOutput(ModelOutput): + """ + Base class for outputs of token classification models. + + Args: + logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.num_labels)`): + Classification scores (before SoftMax). + hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: jax_xla.DeviceArray = None + hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None + attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + + +@dataclass +class FlaxQuestionAnsweringModelOutput(ModelOutput): + """ + Base class for outputs of question answering models. + + Args: + start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): + Span-end scores (before SoftMax). + hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + start_logits: jax_xla.DeviceArray = None + end_logits: jax_xla.DeviceArray = None + hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None + attentions: Optional[Tuple[jax_xla.DeviceArray]] = None diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index b32acd0f7d..51e65f37b2 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -32,12 +32,14 @@ from .file_utils import ( FLAX_WEIGHTS_NAME, WEIGHTS_NAME, PushToHubMixin, + add_code_sample_docstrings, add_start_docstrings_to_model_forward, cached_path, copy_func, hf_bucket_url, is_offline_mode, is_remote_url, + replace_return_docstrings, ) from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict from .utils import logging @@ -432,3 +434,22 @@ def overwrite_call_docstring(model_class, docstring): model_class.__call__.__doc__ = None # set correct docstring model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__) + + +def append_call_sample_docstring(model_class, tokenizer_class, checkpoint, output_type, config_class, mask=None): + model_class.__call__ = copy_func(model_class.__call__) + model_class.__call__ = add_code_sample_docstrings( + tokenizer_class=tokenizer_class, + checkpoint=checkpoint, + output_type=output_type, + config_class=config_class, + model_cls=model_class.__name__, + )(model_class.__call__) + + +def append_replace_return_docstrings(model_class, output_type, config_class): + model_class.__call__ = copy_func(model_class.__call__) + model_class.__call__ = replace_return_docstrings( + output_type=output_type, + config_class=config_class, + )(model_class.__call__) diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 56a167ee85..64b95d2837 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -13,30 +13,79 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Tuple +from dataclasses import dataclass +from typing import Callable, Optional, Tuple import numpy as np import flax.linen as nn import jax import jax.numpy as jnp +import jaxlib.xla_extension as jax_xla from flax.core.frozen_dict import FrozenDict from flax.linen import dot_product_attention from jax import lax from jax.random import PRNGKey -from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, overwrite_call_docstring +from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPooling, + FlaxMaskedLMOutput, + FlaxMultipleChoiceModelOutput, + FlaxNextSentencePredictorOutput, + FlaxQuestionAnsweringModelOutput, + FlaxSequenceClassifierOutput, + FlaxTokenClassifierOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) from ...utils import logging from .configuration_bert import BertConfig logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "bert-base-uncased" _CONFIG_FOR_DOC = "BertConfig" _TOKENIZER_FOR_DOC = "BertTokenizer" +@dataclass +class FlaxBertForPreTrainingOutput(ModelOutput): + """ + Output type of :class:`~transformers.BertForPreTraining`. + + Args: + prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each + layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + prediction_logits: jax_xla.DeviceArray = None + seq_relationship_logits: jax_xla.DeviceArray = None + hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None + attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + + BERT_START_DOCSTRING = r""" This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the @@ -166,7 +215,7 @@ class FlaxBertSelfAttention(nn.Module): kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), ) - def __call__(self, hidden_states, attention_mask, deterministic=True): + def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): head_dim = self.config.hidden_size // self.config.num_attention_heads query_states = self.query(hidden_states).reshape( @@ -208,7 +257,12 @@ class FlaxBertSelfAttention(nn.Module): precision=None, ) - return attn_output.reshape(attn_output.shape[:2] + (-1,)) + outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),) + + # TODO: at the moment it's not possible to retrieve attn_weights from + # dot_product_attention, but should be in the future -> add functionality then + + return outputs class FlaxBertSelfOutput(nn.Module): @@ -239,13 +293,22 @@ class FlaxBertAttention(nn.Module): self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype) self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype) - def __call__(self, hidden_states, attention_mask, deterministic=True): + def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) - attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic) + attn_outputs = self.self( + hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions + ) + attn_output = attn_outputs[0] hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) - return hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += attn_outputs[1] + + return outputs class FlaxBertIntermediate(nn.Module): @@ -295,11 +358,20 @@ class FlaxBertLayer(nn.Module): self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype) self.output = FlaxBertOutput(self.config, dtype=self.dtype) - def __call__(self, hidden_states, attention_mask, deterministic: bool = True): - attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic) + def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False): + attention_outputs = self.attention( + hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions + ) + attention_output = attention_outputs[0] + hidden_states = self.intermediate(attention_output) hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) - return hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + return outputs class FlaxBertLayerCollection(nn.Module): @@ -311,10 +383,40 @@ class FlaxBertLayerCollection(nn.Module): FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) ] - def __call__(self, hidden_states, attention_mask, deterministic: bool = True): + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, layer in enumerate(self.layers): - hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic) - return hidden_states + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states,) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) class FlaxBertEncoder(nn.Module): @@ -324,8 +426,23 @@ class FlaxBertEncoder(nn.Module): def setup(self): self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype) - def __call__(self, hidden_states, attention_mask, deterministic: bool = True): - return self.layer(hidden_states, attention_mask, deterministic=deterministic) + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) class FlaxBertPooler(nn.Module): @@ -456,7 +573,21 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel): params: dict = None, dropout_rng: PRNGKey = None, train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if output_attentions: + raise NotImplementedError( + "Currently attention scores cannot be returned. Please set `output_attentions` to False for now." + ) + # init input tensors if not passed if token_type_ids is None: token_type_ids = jnp.ones_like(input_ids) @@ -479,6 +610,9 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel): jnp.array(token_type_ids, dtype="i4"), jnp.array(position_ids, dtype="i4"), not train, + output_attentions, + output_hidden_states, + return_dict, rngs=rngs, ) @@ -493,17 +627,43 @@ class FlaxBertModule(nn.Module): self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype) self.pooler = FlaxBertPooler(self.config, dtype=self.dtype) - def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): hidden_states = self.embeddings( input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic ) - hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic) + outputs = self.encoder( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + pooled = self.pooler(hidden_states) if self.add_pooling_layer else None - if not self.add_pooling_layer: - return hidden_states + if not return_dict: + # if pooled is None, don't return it + if pooled is None: + return (hidden_states,) + outputs[1:] + return (hidden_states, pooled) + outputs[1:] - pooled = self.pooler(hidden_states) - return hidden_states, pooled + return FlaxBaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=pooled, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( @@ -514,6 +674,11 @@ class FlaxBertModel(FlaxBertPreTrainedModel): module_class = FlaxBertModule +append_call_sample_docstring( + FlaxBertModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC +) + + class FlaxBertForPreTrainingModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 @@ -523,11 +688,27 @@ class FlaxBertForPreTrainingModule(nn.Module): self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype) def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, ): + # Model - hidden_states, pooled_output = self.bert( - input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, ) if self.config.tie_word_embeddings: @@ -535,11 +716,22 @@ class FlaxBertForPreTrainingModule(nn.Module): else: shared_embedding = None + hidden_states = outputs[0] + pooled_output = outputs[1] + prediction_scores, seq_relationship_score = self.cls( hidden_states, pooled_output, shared_embedding=shared_embedding ) - return (prediction_scores, seq_relationship_score) + if not return_dict: + return (prediction_scores, seq_relationship_score) + outputs[2:] + + return FlaxBertForPreTrainingOutput( + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( @@ -553,6 +745,32 @@ class FlaxBertForPreTraining(FlaxBertPreTrainedModel): module_class = FlaxBertForPreTrainingModule +FLAX_BERT_FOR_PRETRAINING_DOCSTRING = """ + Returns: + + Example:: + + >>> from transformers import BertTokenizer, FlaxBertForPreTraining + + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> model = FlaxBertForPreTraining.from_pretrained('bert-base-uncased') + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits +""" + +overwrite_call_docstring( + FlaxBertForPreTraining, + BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_PRETRAINING_DOCSTRING, +) +append_replace_return_docstrings( + FlaxBertForPreTraining, output_type=FlaxBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC +) + + class FlaxBertForMaskedLMModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 @@ -562,11 +780,29 @@ class FlaxBertForMaskedLMModule(nn.Module): self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, ): # Model - hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic) + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] if self.config.tie_word_embeddings: shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] else: @@ -575,7 +811,14 @@ class FlaxBertForMaskedLMModule(nn.Module): # Compute the prediction scores logits = self.cls(hidden_states, shared_embedding=shared_embedding) - return (logits,) + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxMaskedLMOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING) @@ -583,6 +826,11 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel): module_class = FlaxBertForMaskedLMModule +append_call_sample_docstring( + FlaxBertForMaskedLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC +) + + class FlaxBertForNextSentencePredictionModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 @@ -592,15 +840,41 @@ class FlaxBertForNextSentencePredictionModule(nn.Module): self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype) def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, ): + return_dict = return_dict if return_dict is not None else self.config.return_dict + # Model - _, pooled_output = self.bert( - input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, ) + pooled_output = outputs[1] seq_relationship_scores = self.cls(pooled_output) - return (seq_relationship_scores,) + + if not return_dict: + return (seq_relationship_scores,) + outputs[2:] + + return FlaxNextSentencePredictorOutput( + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( @@ -611,6 +885,35 @@ class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel): module_class = FlaxBertForNextSentencePredictionModule +FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING = """ + Returns: + + Example:: + + >>> from transformers import BertTokenizer, FlaxBertForNextSentencePrediction + + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + >>> model = FlaxBertForNextSentencePrediction.from_pretrained('bert-base-uncased') + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors='jax') + + >>> outputs = model(**encoding) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random +""" + + +overwrite_call_docstring( + FlaxBertForNextSentencePrediction, + BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING, +) +append_replace_return_docstrings( + FlaxBertForNextSentencePrediction, output_type=FlaxNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC +) + + class FlaxBertForSequenceClassificationModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 @@ -624,17 +927,40 @@ class FlaxBertForSequenceClassificationModule(nn.Module): ) def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, ): # Model - _, pooled_output = self.bert( - input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, ) + pooled_output = outputs[1] pooled_output = self.dropout(pooled_output, deterministic=deterministic) logits = self.classifier(pooled_output) - return (logits,) + if not return_dict: + return (logits,) + outputs[2:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( @@ -648,6 +974,15 @@ class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel): module_class = FlaxBertForSequenceClassificationModule +append_call_sample_docstring( + FlaxBertForSequenceClassification, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + class FlaxBertForMultipleChoiceModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 @@ -658,7 +993,15 @@ class FlaxBertForMultipleChoiceModule(nn.Module): self.classifier = nn.Dense(1, dtype=self.dtype) def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, ): num_choices = input_ids.shape[1] input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None @@ -667,16 +1010,31 @@ class FlaxBertForMultipleChoiceModule(nn.Module): position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None # Model - _, pooled_output = self.bert( - input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, ) + pooled_output = outputs[1] pooled_output = self.dropout(pooled_output, deterministic=deterministic) logits = self.classifier(pooled_output) reshaped_logits = logits.reshape(-1, num_choices) - return (reshaped_logits,) + if not return_dict: + return (reshaped_logits,) + outputs[2:] + + return FlaxMultipleChoiceModelOutput( + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( @@ -690,10 +1048,12 @@ class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel): module_class = FlaxBertForMultipleChoiceModule -# adapt docstring slightly for FlaxBertForMultipleChoice overwrite_call_docstring( FlaxBertForMultipleChoice, BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") ) +append_call_sample_docstring( + FlaxBertForMultipleChoice, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxMultipleChoiceModelOutput, _CONFIG_FOR_DOC +) class FlaxBertForTokenClassificationModule(nn.Module): @@ -706,15 +1066,40 @@ class FlaxBertForTokenClassificationModule(nn.Module): self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, ): # Model - hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic) + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] hidden_states = self.dropout(hidden_states, deterministic=deterministic) logits = self.classifier(hidden_states) - return (logits,) + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxTokenClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( @@ -728,6 +1113,11 @@ class FlaxBertForTokenClassification(FlaxBertPreTrainedModel): module_class = FlaxBertForTokenClassificationModule +append_call_sample_docstring( + FlaxBertForTokenClassification, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxTokenClassifierOutput, _CONFIG_FOR_DOC +) + + class FlaxBertForQuestionAnsweringModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 @@ -737,17 +1127,44 @@ class FlaxBertForQuestionAnsweringModule(nn.Module): self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( - self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, ): # Model - hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic) + outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] logits = self.qa_outputs(hidden_states) start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) start_logits = start_logits.squeeze(-1) end_logits = end_logits.squeeze(-1) - return (start_logits, end_logits) + if not return_dict: + return (start_logits, end_logits) + outputs[1:] + + return FlaxQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( @@ -759,3 +1176,12 @@ class FlaxBertForQuestionAnsweringModule(nn.Module): ) class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel): module_class = FlaxBertForQuestionAnsweringModule + + +append_call_sample_docstring( + FlaxBertForQuestionAnswering, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index ef0c46660f..5c1fd0706f 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Optional, Tuple import flax.linen as nn import jax @@ -23,13 +23,15 @@ from jax import lax from jax.random import PRNGKey from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel +from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring from ...utils import logging from .configuration_roberta import RobertaConfig logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "roberta-base" _CONFIG_FOR_DOC = "RobertaConfig" _TOKENIZER_FOR_DOC = "RobertaTokenizer" @@ -181,7 +183,7 @@ class FlaxRobertaSelfAttention(nn.Module): kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), ) - def __call__(self, hidden_states, attention_mask, deterministic=True): + def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): head_dim = self.config.hidden_size // self.config.num_attention_heads query_states = self.query(hidden_states).reshape( @@ -223,7 +225,12 @@ class FlaxRobertaSelfAttention(nn.Module): precision=None, ) - return attn_output.reshape(attn_output.shape[:2] + (-1,)) + outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),) + + # TODO: at the moment it's not possible to retrieve attn_weights from + # dot_product_attention, but should be in the future -> add functionality then + + return outputs # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta @@ -256,13 +263,22 @@ class FlaxRobertaAttention(nn.Module): self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype) self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype) - def __call__(self, hidden_states, attention_mask, deterministic=True): + def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False): # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) - attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic) + attn_outputs = self.self( + hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions + ) + attn_output = attn_outputs[0] hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) - return hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += attn_outputs[1] + + return outputs # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta @@ -315,11 +331,20 @@ class FlaxRobertaLayer(nn.Module): self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype) self.output = FlaxRobertaOutput(self.config, dtype=self.dtype) - def __call__(self, hidden_states, attention_mask, deterministic: bool = True): - attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic) + def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False): + attention_outputs = self.attention( + hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions + ) + attention_output = attention_outputs[0] + hidden_states = self.intermediate(attention_output) hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) - return hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + return outputs # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta @@ -332,10 +357,40 @@ class FlaxRobertaLayerCollection(nn.Module): FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) ] - def __call__(self, hidden_states, attention_mask, deterministic: bool = True): + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, layer in enumerate(self.layers): - hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic) - return hidden_states + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states,) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta @@ -346,8 +401,23 @@ class FlaxRobertaEncoder(nn.Module): def setup(self): self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype) - def __call__(self, hidden_states, attention_mask, deterministic: bool = True): - return self.layer(hidden_states, attention_mask, deterministic=deterministic) + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta @@ -412,7 +482,21 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): params: dict = None, dropout_rng: PRNGKey = None, train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if output_attentions: + raise NotImplementedError( + "Currently attention scores cannot be returned." "Please set `output_attentions` to False for now." + ) + # init input tensors if not passed if token_type_ids is None: token_type_ids = jnp.ones_like(input_ids) @@ -435,6 +519,9 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): jnp.array(token_type_ids, dtype="i4"), jnp.array(position_ids, dtype="i4"), not train, + output_attentions, + output_hidden_states, + return_dict, rngs=rngs, ) @@ -450,17 +537,43 @@ class FlaxRobertaModule(nn.Module): self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype) self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype) - def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): hidden_states = self.embeddings( input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic ) - hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic) + outputs = self.encoder( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + pooled = self.pooler(hidden_states) if self.add_pooling_layer else None - if not self.add_pooling_layer: - return hidden_states + if not return_dict: + # if pooled is None, don't return it + if pooled is None: + return (hidden_states,) + outputs[1:] + return (hidden_states, pooled) + outputs[1:] - pooled = self.pooler(hidden_states) - return hidden_states, pooled + return FlaxBaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=pooled, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( @@ -469,3 +582,8 @@ class FlaxRobertaModule(nn.Module): ) class FlaxRobertaModel(FlaxRobertaPreTrainedModel): module_class = FlaxRobertaModule + + +append_call_sample_docstring( + FlaxRobertaModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC +) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d93faa1f6c..d193a9e7a4 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -998,7 +998,6 @@ class ModelTesterMixin: # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head)) def test_model_outputs_equivalence(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() def set_nan_tensor_to_zero(t): diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 8d5ca111fd..dddac75236 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -13,8 +13,10 @@ # limitations under the License. import copy +import inspect import random import tempfile +from typing import List, Tuple import numpy as np @@ -28,6 +30,7 @@ if is_flax_available(): import jax import jax.numpy as jnp + import jaxlib.xla_extension as jax_xla from transformers.modeling_flax_pytorch_utils import ( convert_pytorch_state_dict_to_flax, load_flax_weights_in_pytorch_model, @@ -77,6 +80,7 @@ class FlaxModelTesterMixin: inputs_dict = { k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1])) for k, v in inputs_dict.items() + if isinstance(v, (jax_xla.DeviceArray, np.ndarray)) } return inputs_dict @@ -85,6 +89,41 @@ class FlaxModelTesterMixin: diff = np.abs((a - b)).max() self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") + def test_model_outputs_equivalence(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def set_nan_tensor_to_zero(t): + t[t != t] = 0 + return t + + def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) + dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assert_almost_equals( + set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5 + ) + + recursive_check(tuple_output, dict_output) + + for model_class in self.all_model_classes: + model = model_class(config) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + @is_pt_flax_cross_test def test_equivalence_pt_to_flax(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -108,7 +147,7 @@ class FlaxModelTesterMixin: with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() - fx_outputs = fx_model(**prepared_inputs_dict) + fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3) @@ -117,7 +156,7 @@ class FlaxModelTesterMixin: pt_model.save_pretrained(tmpdirname) fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) - fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict) + fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple() self.assertEqual( len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" ) @@ -149,7 +188,7 @@ class FlaxModelTesterMixin: with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() - fx_outputs = fx_model(**prepared_inputs_dict) + fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3) @@ -171,17 +210,20 @@ class FlaxModelTesterMixin: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: + if model_class.__name__ != "FlaxBertModel": + continue + with self.subTest(model_class.__name__): model = model_class(config) prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - outputs = model(**prepared_inputs_dict) + outputs = model(**prepared_inputs_dict).to_tuple() with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_loaded = model_class.from_pretrained(tmpdirname) - outputs_loaded = model_loaded(**prepared_inputs_dict) + outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple() for output_loaded, output in zip(outputs_loaded, outputs): self.assert_almost_equals(output_loaded, output, 1e-3) @@ -195,19 +237,47 @@ class FlaxModelTesterMixin: @jax.jit def model_jitted(input_ids, attention_mask=None, token_type_ids=None): - return model(input_ids, attention_mask, token_type_ids) + return model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + ).to_tuple() + + with self.subTest("JIT Enabled"): + jitted_outputs = model_jitted(**prepared_inputs_dict) with self.subTest("JIT Disabled"): with jax.disable_jit(): outputs = model_jitted(**prepared_inputs_dict) - with self.subTest("JIT Enabled"): - jitted_outputs = model_jitted(**prepared_inputs_dict) - self.assertEqual(len(outputs), len(jitted_outputs)) for jitted_output, output in zip(jitted_outputs, outputs): self.assertEqual(jitted_output.shape, output.shape) + @jax.jit + def model_jitted_return_dict(input_ids, attention_mask=None, token_type_ids=None): + return model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + ) + + # jitted function cannot return OrderedDict + with self.assertRaises(TypeError): + model_jitted_return_dict(**prepared_inputs_dict) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.__call__) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["input_ids", "attention_mask"] + self.assertListEqual(arg_names[:2], expected_arg_names) + def test_naming_convention(self): for model_class in self.all_model_classes: model_class_name = model_class.__name__ @@ -218,3 +288,30 @@ class FlaxModelTesterMixin: module_cls = getattr(bert_modeling_flax_module, module_class_name) self.assertIsNotNone(module_cls) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + hidden_states = outputs.hidden_states + + self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1) + seq_length = self.model_tester.seq_length + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class)