[Flax] Add docstrings & model outputs (#11498)
* add attentions & hidden states * add model outputs + docs * finish docs * finish tests * finish impl * del @ * finish * finish * correct test * apply sylvains suggestions * Update src/transformers/models/bert/modeling_flax_bert.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * simplify more Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
3f6add8bab
commit
f748bd4242
@@ -794,6 +794,17 @@ PT_CAUSAL_LM_SAMPLE = r"""
|
|||||||
>>> logits = outputs.logits
|
>>> logits = outputs.logits
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
PT_SAMPLE_DOCSTRINGS = {
|
||||||
|
"SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
|
||||||
|
"QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
|
||||||
|
"TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE,
|
||||||
|
"MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE,
|
||||||
|
"MaskedLM": PT_MASKED_LM_SAMPLE,
|
||||||
|
"LMHead": PT_CAUSAL_LM_SAMPLE,
|
||||||
|
"BaseModel": PT_BASE_MODEL_SAMPLE,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
|
TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
@@ -915,30 +926,148 @@ TF_CAUSAL_LM_SAMPLE = r"""
|
|||||||
>>> logits = outputs.logits
|
>>> logits = outputs.logits
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
TF_SAMPLE_DOCSTRINGS = {
|
||||||
|
"SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE,
|
||||||
|
"QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE,
|
||||||
|
"TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE,
|
||||||
|
"MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE,
|
||||||
|
"MaskedLM": TF_MASKED_LM_SAMPLE,
|
||||||
|
"LMHead": TF_CAUSAL_LM_SAMPLE,
|
||||||
|
"BaseModel": TF_BASE_MODEL_SAMPLE,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_TOKEN_CLASSIFICATION_SAMPLE = r"""
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import {tokenizer_class}, {model_class}
|
||||||
|
|
||||||
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||||
|
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||||
|
|
||||||
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
|
||||||
|
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
>>> logits = outputs.logits
|
||||||
|
"""
|
||||||
|
|
||||||
|
FLAX_QUESTION_ANSWERING_SAMPLE = r"""
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import {tokenizer_class}, {model_class}
|
||||||
|
|
||||||
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||||
|
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||||
|
|
||||||
|
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
||||||
|
>>> inputs = tokenizer(question, text, return_tensors='jax')
|
||||||
|
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
>>> start_scores = outputs.start_logits
|
||||||
|
>>> end_scores = outputs.end_logits
|
||||||
|
"""
|
||||||
|
|
||||||
|
FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import {tokenizer_class}, {model_class}
|
||||||
|
|
||||||
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||||
|
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||||
|
|
||||||
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
|
||||||
|
|
||||||
|
>>> outputs = model(**inputs, labels=labels)
|
||||||
|
>>> logits = outputs.logits
|
||||||
|
"""
|
||||||
|
|
||||||
|
FLAX_MASKED_LM_SAMPLE = r"""
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import {tokenizer_class}, {model_class}
|
||||||
|
|
||||||
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||||
|
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||||
|
|
||||||
|
>>> inputs = tokenizer("The capital of France is {mask}.", return_tensors='jax')
|
||||||
|
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
>>> logits = outputs.logits
|
||||||
|
"""
|
||||||
|
|
||||||
|
FLAX_BASE_MODEL_SAMPLE = r"""
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import {tokenizer_class}, {model_class}
|
||||||
|
|
||||||
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||||
|
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||||
|
|
||||||
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
|
"""
|
||||||
|
|
||||||
|
FLAX_MULTIPLE_CHOICE_SAMPLE = r"""
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import {tokenizer_class}, {model_class}
|
||||||
|
|
||||||
|
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||||
|
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||||
|
|
||||||
|
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||||
|
>>> choice0 = "It is eaten with a fork and a knife."
|
||||||
|
>>> choice1 = "It is eaten while held in the hand."
|
||||||
|
|
||||||
|
>>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='jax', padding=True)
|
||||||
|
>>> outputs = model(**{{k: v[None, :] for k,v in encoding.items()}})
|
||||||
|
|
||||||
|
>>> logits = outputs.logits
|
||||||
|
"""
|
||||||
|
|
||||||
|
FLAX_SAMPLE_DOCSTRINGS = {
|
||||||
|
"SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE,
|
||||||
|
"QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE,
|
||||||
|
"TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE,
|
||||||
|
"MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE,
|
||||||
|
"MaskedLM": FLAX_MASKED_LM_SAMPLE,
|
||||||
|
"BaseModel": FLAX_BASE_MODEL_SAMPLE,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def add_code_sample_docstrings(
|
def add_code_sample_docstrings(
|
||||||
*docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None
|
*docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None, model_cls=None
|
||||||
):
|
):
|
||||||
def docstring_decorator(fn):
|
def docstring_decorator(fn):
|
||||||
model_class = fn.__qualname__.split(".")[0]
|
# model_class defaults to function's class if not specified otherwise
|
||||||
is_tf_class = model_class[:2] == "TF"
|
model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls
|
||||||
|
|
||||||
|
if model_class[:2] == "TF":
|
||||||
|
sample_docstrings = TF_SAMPLE_DOCSTRINGS
|
||||||
|
elif model_class[:4] == "Flax":
|
||||||
|
sample_docstrings = FLAX_SAMPLE_DOCSTRINGS
|
||||||
|
else:
|
||||||
|
sample_docstrings = PT_SAMPLE_DOCSTRINGS
|
||||||
|
|
||||||
doc_kwargs = dict(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
|
doc_kwargs = dict(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
|
||||||
|
|
||||||
if "SequenceClassification" in model_class:
|
if "SequenceClassification" in model_class:
|
||||||
code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE
|
code_sample = sample_docstrings["SequenceClassification"]
|
||||||
elif "QuestionAnswering" in model_class:
|
elif "QuestionAnswering" in model_class:
|
||||||
code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE
|
code_sample = sample_docstrings["QuestionAnswering"]
|
||||||
elif "TokenClassification" in model_class:
|
elif "TokenClassification" in model_class:
|
||||||
code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE
|
code_sample = sample_docstrings["TokenClassification"]
|
||||||
elif "MultipleChoice" in model_class:
|
elif "MultipleChoice" in model_class:
|
||||||
code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE
|
code_sample = sample_docstrings["MultipleChoice"]
|
||||||
elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
|
elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
|
||||||
doc_kwargs["mask"] = "[MASK]" if mask is None else mask
|
doc_kwargs["mask"] = "[MASK]" if mask is None else mask
|
||||||
code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE
|
code_sample = sample_docstrings["MaskedLM"]
|
||||||
elif "LMHead" in model_class or "CausalLM" in model_class:
|
elif "LMHead" in model_class or "CausalLM" in model_class:
|
||||||
code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE
|
code_sample = sample_docstrings["LMHead"]
|
||||||
elif "Model" in model_class or "Encoder" in model_class:
|
elif "Model" in model_class or "Encoder" in model_class:
|
||||||
code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE
|
code_sample = sample_docstrings["BaseModel"]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Docstring can't be built for model {model_class}")
|
raise ValueError(f"Docstring can't be built for model {model_class}")
|
||||||
|
|
||||||
@@ -1462,7 +1591,10 @@ def tf_required(func):
|
|||||||
|
|
||||||
|
|
||||||
def is_tensor(x):
|
def is_tensor(x):
|
||||||
"""Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`."""
|
"""
|
||||||
|
Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor`, obj:`jaxlib.xla_extension.DeviceArray` or
|
||||||
|
:obj:`np.ndarray`.
|
||||||
|
"""
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -1473,6 +1605,14 @@ def is_tensor(x):
|
|||||||
|
|
||||||
if isinstance(x, tf.Tensor):
|
if isinstance(x, tf.Tensor):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
import jaxlib.xla_extension as jax_xla
|
||||||
|
from jax.interpreters.partial_eval import DynamicJaxprTracer
|
||||||
|
|
||||||
|
if isinstance(x, (jax_xla.DeviceArray, DynamicJaxprTracer)):
|
||||||
|
return True
|
||||||
|
|
||||||
return isinstance(x, np.ndarray)
|
return isinstance(x, np.ndarray)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
239
src/transformers/modeling_flax_outputs.py
Normal file
239
src/transformers/modeling_flax_outputs.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import jaxlib.xla_extension as jax_xla
|
||||||
|
|
||||||
|
from .file_utils import ModelOutput
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlaxBaseModelOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for model's outputs, with potential hidden states and attentions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
|
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
last_hidden_state: jax_xla.DeviceArray = None
|
||||||
|
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlaxBaseModelOutputWithPooling(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for model's outputs that also contains a pooling of the last hidden states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
|
pooler_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`):
|
||||||
|
Last layer hidden-state of the first token of the sequence (classification token) further processed by a
|
||||||
|
Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
|
||||||
|
prediction (classification) objective during pretraining.
|
||||||
|
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
last_hidden_state: jax_xla.DeviceArray = None
|
||||||
|
pooler_output: jax_xla.DeviceArray = None
|
||||||
|
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlaxMaskedLMOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for masked language models outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
|
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logits: jax_xla.DeviceArray = None
|
||||||
|
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlaxNextSentencePredictorOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`):
|
||||||
|
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
||||||
|
before SoftMax).
|
||||||
|
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logits: jax_xla.DeviceArray = None
|
||||||
|
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlaxSequenceClassifierOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for outputs of sentence classification models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`):
|
||||||
|
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||||
|
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logits: jax_xla.DeviceArray = None
|
||||||
|
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlaxMultipleChoiceModelOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for outputs of multiple choice models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, num_choices)`):
|
||||||
|
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
|
||||||
|
|
||||||
|
Classification scores (before SoftMax).
|
||||||
|
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logits: jax_xla.DeviceArray = None
|
||||||
|
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlaxTokenClassifierOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for outputs of token classification models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
|
||||||
|
Classification scores (before SoftMax).
|
||||||
|
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logits: jax_xla.DeviceArray = None
|
||||||
|
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlaxQuestionAnsweringModelOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for outputs of question answering models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Span-start scores (before SoftMax).
|
||||||
|
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Span-end scores (before SoftMax).
|
||||||
|
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
start_logits: jax_xla.DeviceArray = None
|
||||||
|
end_logits: jax_xla.DeviceArray = None
|
||||||
|
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
@@ -32,12 +32,14 @@ from .file_utils import (
|
|||||||
FLAX_WEIGHTS_NAME,
|
FLAX_WEIGHTS_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
PushToHubMixin,
|
PushToHubMixin,
|
||||||
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
cached_path,
|
cached_path,
|
||||||
copy_func,
|
copy_func,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
|
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
@@ -432,3 +434,22 @@ def overwrite_call_docstring(model_class, docstring):
|
|||||||
model_class.__call__.__doc__ = None
|
model_class.__call__.__doc__ = None
|
||||||
# set correct docstring
|
# set correct docstring
|
||||||
model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
|
model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
|
||||||
|
|
||||||
|
|
||||||
|
def append_call_sample_docstring(model_class, tokenizer_class, checkpoint, output_type, config_class, mask=None):
|
||||||
|
model_class.__call__ = copy_func(model_class.__call__)
|
||||||
|
model_class.__call__ = add_code_sample_docstrings(
|
||||||
|
tokenizer_class=tokenizer_class,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
output_type=output_type,
|
||||||
|
config_class=config_class,
|
||||||
|
model_cls=model_class.__name__,
|
||||||
|
)(model_class.__call__)
|
||||||
|
|
||||||
|
|
||||||
|
def append_replace_return_docstrings(model_class, output_type, config_class):
|
||||||
|
model_class.__call__ = copy_func(model_class.__call__)
|
||||||
|
model_class.__call__ = replace_return_docstrings(
|
||||||
|
output_type=output_type,
|
||||||
|
config_class=config_class,
|
||||||
|
)(model_class.__call__)
|
||||||
|
|||||||
@@ -13,30 +13,79 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Callable, Tuple
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import jaxlib.xla_extension as jax_xla
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict
|
||||||
from flax.linen import dot_product_attention
|
from flax.linen import dot_product_attention
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||||
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, overwrite_call_docstring
|
from ...modeling_flax_outputs import (
|
||||||
|
FlaxBaseModelOutput,
|
||||||
|
FlaxBaseModelOutputWithPooling,
|
||||||
|
FlaxMaskedLMOutput,
|
||||||
|
FlaxMultipleChoiceModelOutput,
|
||||||
|
FlaxNextSentencePredictorOutput,
|
||||||
|
FlaxQuestionAnsweringModelOutput,
|
||||||
|
FlaxSequenceClassifierOutput,
|
||||||
|
FlaxTokenClassifierOutput,
|
||||||
|
)
|
||||||
|
from ...modeling_flax_utils import (
|
||||||
|
ACT2FN,
|
||||||
|
FlaxPreTrainedModel,
|
||||||
|
append_call_sample_docstring,
|
||||||
|
append_replace_return_docstrings,
|
||||||
|
overwrite_call_docstring,
|
||||||
|
)
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_bert import BertConfig
|
from .configuration_bert import BertConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
_CHECKPOINT_FOR_DOC = "bert-base-uncased"
|
||||||
_CONFIG_FOR_DOC = "BertConfig"
|
_CONFIG_FOR_DOC = "BertConfig"
|
||||||
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlaxBertForPreTrainingOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Output type of :class:`~transformers.BertForPreTraining`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
|
seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`):
|
||||||
|
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
||||||
|
before SoftMax).
|
||||||
|
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
||||||
|
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
|
heads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
prediction_logits: jax_xla.DeviceArray = None
|
||||||
|
seq_relationship_logits: jax_xla.DeviceArray = None
|
||||||
|
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
||||||
|
|
||||||
|
|
||||||
BERT_START_DOCSTRING = r"""
|
BERT_START_DOCSTRING = r"""
|
||||||
|
|
||||||
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
|
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
|
||||||
@@ -166,7 +215,7 @@ class FlaxBertSelfAttention(nn.Module):
|
|||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
||||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||||
|
|
||||||
query_states = self.query(hidden_states).reshape(
|
query_states = self.query(hidden_states).reshape(
|
||||||
@@ -208,7 +257,12 @@ class FlaxBertSelfAttention(nn.Module):
|
|||||||
precision=None,
|
precision=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
return attn_output.reshape(attn_output.shape[:2] + (-1,))
|
outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)
|
||||||
|
|
||||||
|
# TODO: at the moment it's not possible to retrieve attn_weights from
|
||||||
|
# dot_product_attention, but should be in the future -> add functionality then
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertSelfOutput(nn.Module):
|
class FlaxBertSelfOutput(nn.Module):
|
||||||
@@ -239,13 +293,22 @@ class FlaxBertAttention(nn.Module):
|
|||||||
self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype)
|
self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype)
|
||||||
self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)
|
self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
||||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||||
attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic)
|
attn_outputs = self.self(
|
||||||
|
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
||||||
|
)
|
||||||
|
attn_output = attn_outputs[0]
|
||||||
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
||||||
return hidden_states
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += attn_outputs[1]
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertIntermediate(nn.Module):
|
class FlaxBertIntermediate(nn.Module):
|
||||||
@@ -295,11 +358,20 @@ class FlaxBertLayer(nn.Module):
|
|||||||
self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
|
self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
|
||||||
self.output = FlaxBertOutput(self.config, dtype=self.dtype)
|
self.output = FlaxBertOutput(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False):
|
||||||
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
|
attention_outputs = self.attention(
|
||||||
|
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
||||||
|
)
|
||||||
|
attention_output = attention_outputs[0]
|
||||||
|
|
||||||
hidden_states = self.intermediate(attention_output)
|
hidden_states = self.intermediate(attention_output)
|
||||||
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
||||||
return hidden_states
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (attention_outputs[1],)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertLayerCollection(nn.Module):
|
class FlaxBertLayerCollection(nn.Module):
|
||||||
@@ -311,10 +383,40 @@ class FlaxBertLayerCollection(nn.Module):
|
|||||||
FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
output_hidden_states: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
|
):
|
||||||
|
all_attentions = () if output_attentions else None
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
|
if output_hidden_states:
|
||||||
return hidden_states
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_attentions += (layer_outputs[1],)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in outputs if v is not None)
|
||||||
|
|
||||||
|
return FlaxBaseModelOutput(
|
||||||
|
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertEncoder(nn.Module):
|
class FlaxBertEncoder(nn.Module):
|
||||||
@@ -324,8 +426,23 @@ class FlaxBertEncoder(nn.Module):
|
|||||||
def setup(self):
|
def setup(self):
|
||||||
self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype)
|
self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
def __call__(
|
||||||
return self.layer(hidden_states, attention_mask, deterministic=deterministic)
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
output_hidden_states: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
|
):
|
||||||
|
return self.layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertPooler(nn.Module):
|
class FlaxBertPooler(nn.Module):
|
||||||
@@ -456,7 +573,21 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params: dict = None,
|
params: dict = None,
|
||||||
dropout_rng: PRNGKey = None,
|
dropout_rng: PRNGKey = None,
|
||||||
train: bool = False,
|
train: bool = False,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Currently attention scores cannot be returned. Please set `output_attentions` to False for now."
|
||||||
|
)
|
||||||
|
|
||||||
# init input tensors if not passed
|
# init input tensors if not passed
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = jnp.ones_like(input_ids)
|
token_type_ids = jnp.ones_like(input_ids)
|
||||||
@@ -479,6 +610,9 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
jnp.array(token_type_ids, dtype="i4"),
|
jnp.array(token_type_ids, dtype="i4"),
|
||||||
jnp.array(position_ids, dtype="i4"),
|
jnp.array(position_ids, dtype="i4"),
|
||||||
not train,
|
not train,
|
||||||
|
output_attentions,
|
||||||
|
output_hidden_states,
|
||||||
|
return_dict,
|
||||||
rngs=rngs,
|
rngs=rngs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -493,17 +627,43 @@ class FlaxBertModule(nn.Module):
|
|||||||
self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
|
self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
|
||||||
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)
|
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
def __call__(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
output_hidden_states: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
|
):
|
||||||
hidden_states = self.embeddings(
|
hidden_states = self.embeddings(
|
||||||
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
||||||
)
|
)
|
||||||
hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic)
|
outputs = self.encoder(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
|
||||||
|
|
||||||
if not self.add_pooling_layer:
|
if not return_dict:
|
||||||
return hidden_states
|
# if pooled is None, don't return it
|
||||||
|
if pooled is None:
|
||||||
|
return (hidden_states,) + outputs[1:]
|
||||||
|
return (hidden_states, pooled) + outputs[1:]
|
||||||
|
|
||||||
pooled = self.pooler(hidden_states)
|
return FlaxBaseModelOutputWithPooling(
|
||||||
return hidden_states, pooled
|
last_hidden_state=hidden_states,
|
||||||
|
pooler_output=pooled,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
@@ -514,6 +674,11 @@ class FlaxBertModel(FlaxBertPreTrainedModel):
|
|||||||
module_class = FlaxBertModule
|
module_class = FlaxBertModule
|
||||||
|
|
||||||
|
|
||||||
|
append_call_sample_docstring(
|
||||||
|
FlaxBertModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertForPreTrainingModule(nn.Module):
|
class FlaxBertForPreTrainingModule(nn.Module):
|
||||||
config: BertConfig
|
config: BertConfig
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
@@ -523,11 +688,27 @@ class FlaxBertForPreTrainingModule(nn.Module):
|
|||||||
self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)
|
self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
self,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
output_hidden_states: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
):
|
):
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
hidden_states, pooled_output = self.bert(
|
outputs = self.bert(
|
||||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.config.tie_word_embeddings:
|
if self.config.tie_word_embeddings:
|
||||||
@@ -535,11 +716,22 @@ class FlaxBertForPreTrainingModule(nn.Module):
|
|||||||
else:
|
else:
|
||||||
shared_embedding = None
|
shared_embedding = None
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
prediction_scores, seq_relationship_score = self.cls(
|
prediction_scores, seq_relationship_score = self.cls(
|
||||||
hidden_states, pooled_output, shared_embedding=shared_embedding
|
hidden_states, pooled_output, shared_embedding=shared_embedding
|
||||||
)
|
)
|
||||||
|
|
||||||
return (prediction_scores, seq_relationship_score)
|
if not return_dict:
|
||||||
|
return (prediction_scores, seq_relationship_score) + outputs[2:]
|
||||||
|
|
||||||
|
return FlaxBertForPreTrainingOutput(
|
||||||
|
prediction_logits=prediction_scores,
|
||||||
|
seq_relationship_logits=seq_relationship_score,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
@@ -553,6 +745,32 @@ class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
|
|||||||
module_class = FlaxBertForPreTrainingModule
|
module_class = FlaxBertForPreTrainingModule
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_BERT_FOR_PRETRAINING_DOCSTRING = """
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import BertTokenizer, FlaxBertForPreTraining
|
||||||
|
|
||||||
|
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||||
|
>>> model = FlaxBertForPreTraining.from_pretrained('bert-base-uncased')
|
||||||
|
|
||||||
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
|
||||||
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
|
>>> prediction_logits = outputs.prediction_logits
|
||||||
|
>>> seq_relationship_logits = outputs.seq_relationship_logits
|
||||||
|
"""
|
||||||
|
|
||||||
|
overwrite_call_docstring(
|
||||||
|
FlaxBertForPreTraining,
|
||||||
|
BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_PRETRAINING_DOCSTRING,
|
||||||
|
)
|
||||||
|
append_replace_return_docstrings(
|
||||||
|
FlaxBertForPreTraining, output_type=FlaxBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertForMaskedLMModule(nn.Module):
|
class FlaxBertForMaskedLMModule(nn.Module):
|
||||||
config: BertConfig
|
config: BertConfig
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
@@ -562,11 +780,29 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
|||||||
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)
|
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
self,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
output_hidden_states: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
):
|
):
|
||||||
# Model
|
# Model
|
||||||
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
outputs = self.bert(
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
if self.config.tie_word_embeddings:
|
if self.config.tie_word_embeddings:
|
||||||
shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
||||||
else:
|
else:
|
||||||
@@ -575,7 +811,14 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
|||||||
# Compute the prediction scores
|
# Compute the prediction scores
|
||||||
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
|
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
|
||||||
|
|
||||||
return (logits,)
|
if not return_dict:
|
||||||
|
return (logits,) + outputs[1:]
|
||||||
|
|
||||||
|
return FlaxMaskedLMOutput(
|
||||||
|
logits=logits,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
||||||
@@ -583,6 +826,11 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
|
|||||||
module_class = FlaxBertForMaskedLMModule
|
module_class = FlaxBertForMaskedLMModule
|
||||||
|
|
||||||
|
|
||||||
|
append_call_sample_docstring(
|
||||||
|
FlaxBertForMaskedLM, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertForNextSentencePredictionModule(nn.Module):
|
class FlaxBertForNextSentencePredictionModule(nn.Module):
|
||||||
config: BertConfig
|
config: BertConfig
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
@@ -592,15 +840,41 @@ class FlaxBertForNextSentencePredictionModule(nn.Module):
|
|||||||
self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)
|
self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
self,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
output_hidden_states: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
):
|
):
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
_, pooled_output = self.bert(
|
outputs = self.bert(
|
||||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pooled_output = outputs[1]
|
||||||
seq_relationship_scores = self.cls(pooled_output)
|
seq_relationship_scores = self.cls(pooled_output)
|
||||||
return (seq_relationship_scores,)
|
|
||||||
|
if not return_dict:
|
||||||
|
return (seq_relationship_scores,) + outputs[2:]
|
||||||
|
|
||||||
|
return FlaxNextSentencePredictorOutput(
|
||||||
|
logits=seq_relationship_scores,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
@@ -611,6 +885,35 @@ class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
|
|||||||
module_class = FlaxBertForNextSentencePredictionModule
|
module_class = FlaxBertForNextSentencePredictionModule
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING = """
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import BertTokenizer, FlaxBertForNextSentencePrediction
|
||||||
|
|
||||||
|
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||||
|
>>> model = FlaxBertForNextSentencePrediction.from_pretrained('bert-base-uncased')
|
||||||
|
|
||||||
|
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||||
|
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
||||||
|
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='jax')
|
||||||
|
|
||||||
|
>>> outputs = model(**encoding)
|
||||||
|
>>> logits = outputs.logits
|
||||||
|
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
overwrite_call_docstring(
|
||||||
|
FlaxBertForNextSentencePrediction,
|
||||||
|
BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING,
|
||||||
|
)
|
||||||
|
append_replace_return_docstrings(
|
||||||
|
FlaxBertForNextSentencePrediction, output_type=FlaxNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertForSequenceClassificationModule(nn.Module):
|
class FlaxBertForSequenceClassificationModule(nn.Module):
|
||||||
config: BertConfig
|
config: BertConfig
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
@@ -624,17 +927,40 @@ class FlaxBertForSequenceClassificationModule(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
self,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
output_hidden_states: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
):
|
):
|
||||||
# Model
|
# Model
|
||||||
_, pooled_output = self.bert(
|
outputs = self.bert(
|
||||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pooled_output = outputs[1]
|
||||||
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
|
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
|
||||||
logits = self.classifier(pooled_output)
|
logits = self.classifier(pooled_output)
|
||||||
|
|
||||||
return (logits,)
|
if not return_dict:
|
||||||
|
return (logits,) + outputs[2:]
|
||||||
|
|
||||||
|
return FlaxSequenceClassifierOutput(
|
||||||
|
logits=logits,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
@@ -648,6 +974,15 @@ class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
|
|||||||
module_class = FlaxBertForSequenceClassificationModule
|
module_class = FlaxBertForSequenceClassificationModule
|
||||||
|
|
||||||
|
|
||||||
|
append_call_sample_docstring(
|
||||||
|
FlaxBertForSequenceClassification,
|
||||||
|
_TOKENIZER_FOR_DOC,
|
||||||
|
_CHECKPOINT_FOR_DOC,
|
||||||
|
FlaxSequenceClassifierOutput,
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertForMultipleChoiceModule(nn.Module):
|
class FlaxBertForMultipleChoiceModule(nn.Module):
|
||||||
config: BertConfig
|
config: BertConfig
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
@@ -658,7 +993,15 @@ class FlaxBertForMultipleChoiceModule(nn.Module):
|
|||||||
self.classifier = nn.Dense(1, dtype=self.dtype)
|
self.classifier = nn.Dense(1, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
self,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
output_hidden_states: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
):
|
):
|
||||||
num_choices = input_ids.shape[1]
|
num_choices = input_ids.shape[1]
|
||||||
input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
|
input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
|
||||||
@@ -667,16 +1010,31 @@ class FlaxBertForMultipleChoiceModule(nn.Module):
|
|||||||
position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
|
position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
_, pooled_output = self.bert(
|
outputs = self.bert(
|
||||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pooled_output = outputs[1]
|
||||||
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
|
pooled_output = self.dropout(pooled_output, deterministic=deterministic)
|
||||||
logits = self.classifier(pooled_output)
|
logits = self.classifier(pooled_output)
|
||||||
|
|
||||||
reshaped_logits = logits.reshape(-1, num_choices)
|
reshaped_logits = logits.reshape(-1, num_choices)
|
||||||
|
|
||||||
return (reshaped_logits,)
|
if not return_dict:
|
||||||
|
return (reshaped_logits,) + outputs[2:]
|
||||||
|
|
||||||
|
return FlaxMultipleChoiceModelOutput(
|
||||||
|
logits=reshaped_logits,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
@@ -690,10 +1048,12 @@ class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
|
|||||||
module_class = FlaxBertForMultipleChoiceModule
|
module_class = FlaxBertForMultipleChoiceModule
|
||||||
|
|
||||||
|
|
||||||
# adapt docstring slightly for FlaxBertForMultipleChoice
|
|
||||||
overwrite_call_docstring(
|
overwrite_call_docstring(
|
||||||
FlaxBertForMultipleChoice, BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
|
FlaxBertForMultipleChoice, BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
|
||||||
)
|
)
|
||||||
|
append_call_sample_docstring(
|
||||||
|
FlaxBertForMultipleChoice, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxMultipleChoiceModelOutput, _CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertForTokenClassificationModule(nn.Module):
|
class FlaxBertForTokenClassificationModule(nn.Module):
|
||||||
@@ -706,15 +1066,40 @@ class FlaxBertForTokenClassificationModule(nn.Module):
|
|||||||
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
self,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
output_hidden_states: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
):
|
):
|
||||||
# Model
|
# Model
|
||||||
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
outputs = self.bert(
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||||
logits = self.classifier(hidden_states)
|
logits = self.classifier(hidden_states)
|
||||||
|
|
||||||
return (logits,)
|
if not return_dict:
|
||||||
|
return (logits,) + outputs[1:]
|
||||||
|
|
||||||
|
return FlaxTokenClassifierOutput(
|
||||||
|
logits=logits,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
@@ -728,6 +1113,11 @@ class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
|
|||||||
module_class = FlaxBertForTokenClassificationModule
|
module_class = FlaxBertForTokenClassificationModule
|
||||||
|
|
||||||
|
|
||||||
|
append_call_sample_docstring(
|
||||||
|
FlaxBertForTokenClassification, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxTokenClassifierOutput, _CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertForQuestionAnsweringModule(nn.Module):
|
class FlaxBertForQuestionAnsweringModule(nn.Module):
|
||||||
config: BertConfig
|
config: BertConfig
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
@@ -737,17 +1127,44 @@ class FlaxBertForQuestionAnsweringModule(nn.Module):
|
|||||||
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
self,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
output_hidden_states: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
):
|
):
|
||||||
# Model
|
# Model
|
||||||
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
outputs = self.bert(
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
|
||||||
logits = self.qa_outputs(hidden_states)
|
logits = self.qa_outputs(hidden_states)
|
||||||
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
||||||
start_logits = start_logits.squeeze(-1)
|
start_logits = start_logits.squeeze(-1)
|
||||||
end_logits = end_logits.squeeze(-1)
|
end_logits = end_logits.squeeze(-1)
|
||||||
|
|
||||||
return (start_logits, end_logits)
|
if not return_dict:
|
||||||
|
return (start_logits, end_logits) + outputs[1:]
|
||||||
|
|
||||||
|
return FlaxQuestionAnsweringModelOutput(
|
||||||
|
start_logits=start_logits,
|
||||||
|
end_logits=end_logits,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
@@ -759,3 +1176,12 @@ class FlaxBertForQuestionAnsweringModule(nn.Module):
|
|||||||
)
|
)
|
||||||
class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
|
class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
|
||||||
module_class = FlaxBertForQuestionAnsweringModule
|
module_class = FlaxBertForQuestionAnsweringModule
|
||||||
|
|
||||||
|
|
||||||
|
append_call_sample_docstring(
|
||||||
|
FlaxBertForQuestionAnswering,
|
||||||
|
_TOKENIZER_FOR_DOC,
|
||||||
|
_CHECKPOINT_FOR_DOC,
|
||||||
|
FlaxQuestionAnsweringModelOutput,
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
@@ -23,13 +23,15 @@ from jax import lax
|
|||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||||
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
|
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling
|
||||||
|
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_roberta import RobertaConfig
|
from .configuration_roberta import RobertaConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
_CHECKPOINT_FOR_DOC = "roberta-base"
|
||||||
_CONFIG_FOR_DOC = "RobertaConfig"
|
_CONFIG_FOR_DOC = "RobertaConfig"
|
||||||
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
|
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
|
||||||
|
|
||||||
@@ -181,7 +183,7 @@ class FlaxRobertaSelfAttention(nn.Module):
|
|||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
||||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||||
|
|
||||||
query_states = self.query(hidden_states).reshape(
|
query_states = self.query(hidden_states).reshape(
|
||||||
@@ -223,7 +225,12 @@ class FlaxRobertaSelfAttention(nn.Module):
|
|||||||
precision=None,
|
precision=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
return attn_output.reshape(attn_output.shape[:2] + (-1,))
|
outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),)
|
||||||
|
|
||||||
|
# TODO: at the moment it's not possible to retrieve attn_weights from
|
||||||
|
# dot_product_attention, but should be in the future -> add functionality then
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta
|
||||||
@@ -256,13 +263,22 @@ class FlaxRobertaAttention(nn.Module):
|
|||||||
self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype)
|
self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype)
|
||||||
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
|
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
|
||||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||||
attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic)
|
attn_outputs = self.self(
|
||||||
|
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
||||||
|
)
|
||||||
|
attn_output = attn_outputs[0]
|
||||||
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
||||||
return hidden_states
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += attn_outputs[1]
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
|
||||||
@@ -315,11 +331,20 @@ class FlaxRobertaLayer(nn.Module):
|
|||||||
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
|
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
|
||||||
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
|
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True, output_attentions: bool = False):
|
||||||
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
|
attention_outputs = self.attention(
|
||||||
|
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
|
||||||
|
)
|
||||||
|
attention_output = attention_outputs[0]
|
||||||
|
|
||||||
hidden_states = self.intermediate(attention_output)
|
hidden_states = self.intermediate(attention_output)
|
||||||
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
||||||
return hidden_states
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (attention_outputs[1],)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
|
||||||
@@ -332,10 +357,40 @@ class FlaxRobertaLayerCollection(nn.Module):
|
|||||||
FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
def __call__(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
output_hidden_states: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
|
):
|
||||||
|
all_attentions = () if output_attentions else None
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
|
if output_hidden_states:
|
||||||
return hidden_states
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_attentions += (layer_outputs[1],)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in outputs if v is not None)
|
||||||
|
|
||||||
|
return FlaxBaseModelOutput(
|
||||||
|
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
|
||||||
@@ -346,8 +401,23 @@ class FlaxRobertaEncoder(nn.Module):
|
|||||||
def setup(self):
|
def setup(self):
|
||||||
self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype)
|
self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
def __call__(
|
||||||
return self.layer(hidden_states, attention_mask, deterministic=deterministic)
|
self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
output_hidden_states: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
|
):
|
||||||
|
return self.layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
|
||||||
@@ -412,7 +482,21 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params: dict = None,
|
params: dict = None,
|
||||||
dropout_rng: PRNGKey = None,
|
dropout_rng: PRNGKey = None,
|
||||||
train: bool = False,
|
train: bool = False,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Currently attention scores cannot be returned." "Please set `output_attentions` to False for now."
|
||||||
|
)
|
||||||
|
|
||||||
# init input tensors if not passed
|
# init input tensors if not passed
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = jnp.ones_like(input_ids)
|
token_type_ids = jnp.ones_like(input_ids)
|
||||||
@@ -435,6 +519,9 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
jnp.array(token_type_ids, dtype="i4"),
|
jnp.array(token_type_ids, dtype="i4"),
|
||||||
jnp.array(position_ids, dtype="i4"),
|
jnp.array(position_ids, dtype="i4"),
|
||||||
not train,
|
not train,
|
||||||
|
output_attentions,
|
||||||
|
output_hidden_states,
|
||||||
|
return_dict,
|
||||||
rngs=rngs,
|
rngs=rngs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -450,17 +537,43 @@ class FlaxRobertaModule(nn.Module):
|
|||||||
self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype)
|
self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype)
|
||||||
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
|
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
def __call__(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
deterministic: bool = True,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
output_hidden_states: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
|
):
|
||||||
hidden_states = self.embeddings(
|
hidden_states = self.embeddings(
|
||||||
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
||||||
)
|
)
|
||||||
hidden_states = self.encoder(hidden_states, attention_mask, deterministic=deterministic)
|
outputs = self.encoder(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
deterministic=deterministic,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
|
||||||
|
|
||||||
if not self.add_pooling_layer:
|
if not return_dict:
|
||||||
return hidden_states
|
# if pooled is None, don't return it
|
||||||
|
if pooled is None:
|
||||||
|
return (hidden_states,) + outputs[1:]
|
||||||
|
return (hidden_states, pooled) + outputs[1:]
|
||||||
|
|
||||||
pooled = self.pooler(hidden_states)
|
return FlaxBaseModelOutputWithPooling(
|
||||||
return hidden_states, pooled
|
last_hidden_state=hidden_states,
|
||||||
|
pooler_output=pooled,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
@@ -469,3 +582,8 @@ class FlaxRobertaModule(nn.Module):
|
|||||||
)
|
)
|
||||||
class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
|
class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
|
||||||
module_class = FlaxRobertaModule
|
module_class = FlaxRobertaModule
|
||||||
|
|
||||||
|
|
||||||
|
append_call_sample_docstring(
|
||||||
|
FlaxRobertaModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
|||||||
@@ -998,7 +998,6 @@ class ModelTesterMixin:
|
|||||||
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
||||||
|
|
||||||
def test_model_outputs_equivalence(self):
|
def test_model_outputs_equivalence(self):
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
def set_nan_tensor_to_zero(t):
|
def set_nan_tensor_to_zero(t):
|
||||||
|
|||||||
@@ -13,8 +13,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import inspect
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -28,6 +30,7 @@ if is_flax_available():
|
|||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
import jaxlib.xla_extension as jax_xla
|
||||||
from transformers.modeling_flax_pytorch_utils import (
|
from transformers.modeling_flax_pytorch_utils import (
|
||||||
convert_pytorch_state_dict_to_flax,
|
convert_pytorch_state_dict_to_flax,
|
||||||
load_flax_weights_in_pytorch_model,
|
load_flax_weights_in_pytorch_model,
|
||||||
@@ -77,6 +80,7 @@ class FlaxModelTesterMixin:
|
|||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
|
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
|
||||||
for k, v in inputs_dict.items()
|
for k, v in inputs_dict.items()
|
||||||
|
if isinstance(v, (jax_xla.DeviceArray, np.ndarray))
|
||||||
}
|
}
|
||||||
|
|
||||||
return inputs_dict
|
return inputs_dict
|
||||||
@@ -85,6 +89,41 @@ class FlaxModelTesterMixin:
|
|||||||
diff = np.abs((a - b)).max()
|
diff = np.abs((a - b)).max()
|
||||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||||
|
|
||||||
|
def test_model_outputs_equivalence(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
def set_nan_tensor_to_zero(t):
|
||||||
|
t[t != t] = 0
|
||||||
|
return t
|
||||||
|
|
||||||
|
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||||
|
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||||
|
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||||
|
|
||||||
|
def recursive_check(tuple_object, dict_object):
|
||||||
|
if isinstance(tuple_object, (List, Tuple)):
|
||||||
|
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
||||||
|
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||||
|
elif tuple_object is None:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self.assert_almost_equals(
|
||||||
|
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
|
||||||
|
)
|
||||||
|
|
||||||
|
recursive_check(tuple_output, dict_output)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||||
|
|
||||||
|
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
@is_pt_flax_cross_test
|
@is_pt_flax_cross_test
|
||||||
def test_equivalence_pt_to_flax(self):
|
def test_equivalence_pt_to_flax(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -108,7 +147,7 @@ class FlaxModelTesterMixin:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
fx_outputs = fx_model(**prepared_inputs_dict)
|
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||||
@@ -117,7 +156,7 @@ class FlaxModelTesterMixin:
|
|||||||
pt_model.save_pretrained(tmpdirname)
|
pt_model.save_pretrained(tmpdirname)
|
||||||
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
|
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||||
)
|
)
|
||||||
@@ -149,7 +188,7 @@ class FlaxModelTesterMixin:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
fx_outputs = fx_model(**prepared_inputs_dict)
|
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||||
@@ -171,17 +210,20 @@ class FlaxModelTesterMixin:
|
|||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
|
if model_class.__name__ != "FlaxBertModel":
|
||||||
|
continue
|
||||||
|
|
||||||
with self.subTest(model_class.__name__):
|
with self.subTest(model_class.__name__):
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|
||||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
outputs = model(**prepared_inputs_dict)
|
outputs = model(**prepared_inputs_dict).to_tuple()
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model.save_pretrained(tmpdirname)
|
model.save_pretrained(tmpdirname)
|
||||||
model_loaded = model_class.from_pretrained(tmpdirname)
|
model_loaded = model_class.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
outputs_loaded = model_loaded(**prepared_inputs_dict)
|
outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
|
||||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||||
self.assert_almost_equals(output_loaded, output, 1e-3)
|
self.assert_almost_equals(output_loaded, output, 1e-3)
|
||||||
|
|
||||||
@@ -195,19 +237,47 @@ class FlaxModelTesterMixin:
|
|||||||
|
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
|
def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
|
||||||
return model(input_ids, attention_mask, token_type_ids)
|
return model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
).to_tuple()
|
||||||
|
|
||||||
|
with self.subTest("JIT Enabled"):
|
||||||
|
jitted_outputs = model_jitted(**prepared_inputs_dict)
|
||||||
|
|
||||||
with self.subTest("JIT Disabled"):
|
with self.subTest("JIT Disabled"):
|
||||||
with jax.disable_jit():
|
with jax.disable_jit():
|
||||||
outputs = model_jitted(**prepared_inputs_dict)
|
outputs = model_jitted(**prepared_inputs_dict)
|
||||||
|
|
||||||
with self.subTest("JIT Enabled"):
|
|
||||||
jitted_outputs = model_jitted(**prepared_inputs_dict)
|
|
||||||
|
|
||||||
self.assertEqual(len(outputs), len(jitted_outputs))
|
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||||
self.assertEqual(jitted_output.shape, output.shape)
|
self.assertEqual(jitted_output.shape, output.shape)
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def model_jitted_return_dict(input_ids, attention_mask=None, token_type_ids=None):
|
||||||
|
return model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# jitted function cannot return OrderedDict
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
model_jitted_return_dict(**prepared_inputs_dict)
|
||||||
|
|
||||||
|
def test_forward_signature(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
signature = inspect.signature(model.__call__)
|
||||||
|
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||||
|
arg_names = [*signature.parameters.keys()]
|
||||||
|
|
||||||
|
expected_arg_names = ["input_ids", "attention_mask"]
|
||||||
|
self.assertListEqual(arg_names[:2], expected_arg_names)
|
||||||
|
|
||||||
def test_naming_convention(self):
|
def test_naming_convention(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model_class_name = model_class.__name__
|
model_class_name = model_class.__name__
|
||||||
@@ -218,3 +288,30 @@ class FlaxModelTesterMixin:
|
|||||||
module_cls = getattr(bert_modeling_flax_module, module_class_name)
|
module_cls = getattr(bert_modeling_flax_module, module_class_name)
|
||||||
|
|
||||||
self.assertIsNotNone(module_cls)
|
self.assertIsNotNone(module_cls)
|
||||||
|
|
||||||
|
def test_hidden_states_output(self):
|
||||||
|
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
hidden_states = outputs.hidden_states
|
||||||
|
|
||||||
|
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
|
||||||
|
seq_length = self.model_tester.seq_length
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
list(hidden_states[0].shape[-2:]),
|
||||||
|
[seq_length, self.model_tester.hidden_size],
|
||||||
|
)
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
inputs_dict["output_hidden_states"] = True
|
||||||
|
check_hidden_states_output(inputs_dict, config, model_class)
|
||||||
|
|
||||||
|
# check that output_hidden_states also work using config
|
||||||
|
del inputs_dict["output_hidden_states"]
|
||||||
|
config.output_hidden_states = True
|
||||||
|
|
||||||
|
check_hidden_states_output(inputs_dict, config, model_class)
|
||||||
|
|||||||
Reference in New Issue
Block a user