Flax/Jax documentation (#8331)
* First addition of Flax/Jax documentation Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * make style * Ensure input order match between Bert & Roberta Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Install dependencies "all" when building doc Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * wraps build_doc deps with "" Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Addressing @sgugger comments. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Use list to highlight JAX features. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Make style. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Let's not look too much into the future for now. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Style Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -281,7 +281,7 @@ jobs:
|
|||||||
- v0.4-build_doc-{{ checksum "setup.py" }}
|
- v0.4-build_doc-{{ checksum "setup.py" }}
|
||||||
- v0.4-{{ checksum "setup.py" }}
|
- v0.4-{{ checksum "setup.py" }}
|
||||||
- run: pip install --upgrade pip
|
- run: pip install --upgrade pip
|
||||||
- run: pip install .[tf,torch,sentencepiece,docs]
|
- run: pip install ."[all, docs]"
|
||||||
- save_cache:
|
- save_cache:
|
||||||
key: v0.4-build_doc-{{ checksum "setup.py" }}
|
key: v0.4-build_doc-{{ checksum "setup.py" }}
|
||||||
paths:
|
paths:
|
||||||
|
|||||||
@@ -188,3 +188,10 @@ TFBertForQuestionAnswering
|
|||||||
|
|
||||||
.. autoclass:: transformers.TFBertForQuestionAnswering
|
.. autoclass:: transformers.TFBertForQuestionAnswering
|
||||||
:members: call
|
:members: call
|
||||||
|
|
||||||
|
|
||||||
|
FlaxBertModel
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxBertModel
|
||||||
|
:members: __call__
|
||||||
|
|||||||
@@ -146,3 +146,10 @@ TFRobertaForQuestionAnswering
|
|||||||
|
|
||||||
.. autoclass:: transformers.TFRobertaForQuestionAnswering
|
.. autoclass:: transformers.TFRobertaForQuestionAnswering
|
||||||
:members: call
|
:members: call
|
||||||
|
|
||||||
|
|
||||||
|
FlaxRobertaModel
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxRobertaModel
|
||||||
|
:members: __call__
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import jax
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from .configuration_bert import BertConfig
|
from .configuration_bert import BertConfig
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||||
from .modeling_flax_utils import FlaxPreTrainedModel, gelu
|
from .modeling_flax_utils import FlaxPreTrainedModel, gelu
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
@@ -35,13 +35,20 @@ _TOKENIZER_FOR_DOC = "BertTokenizer"
|
|||||||
|
|
||||||
BERT_START_DOCSTRING = r"""
|
BERT_START_DOCSTRING = r"""
|
||||||
|
|
||||||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
|
||||||
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
|
generic methods the library implements for all its models (such as downloading, saving and converting weights from
|
||||||
pruning heads etc.)
|
PyTorch models)
|
||||||
|
|
||||||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
|
This model is also a Flax Linen `flax.linen.Module
|
||||||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
|
<https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax
|
||||||
general usage and behavior.
|
Module and refer to the Flax documentation for all matters related to general usage and behavior.
|
||||||
|
|
||||||
|
Finally, this model supports inherent JAX features such as:
|
||||||
|
|
||||||
|
- `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
|
||||||
|
- `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
|
||||||
|
- `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
|
||||||
|
- `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
||||||
@@ -52,50 +59,32 @@ BERT_START_DOCSTRING = r"""
|
|||||||
|
|
||||||
BERT_INPUTS_DOCSTRING = r"""
|
BERT_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
|
input_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`):
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
|
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
|
||||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||||
details.
|
details.
|
||||||
|
|
||||||
`What are input IDs? <../glossary.html#input-ids>`__
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
|
attention_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 for tokens that are **not masked**,
|
- 1 for tokens that are **not masked**,
|
||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
token_type_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
||||||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
||||||
1]``:
|
1]``:
|
||||||
|
|
||||||
- 0 corresponds to a `sentence A` token,
|
- 0 corresponds to a `sentence A` token,
|
||||||
- 1 corresponds to a `sentence B` token.
|
- 1 corresponds to a `sentence B` token.
|
||||||
|
|
||||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
`What are token type IDs? <../glossary.html#token-type-ids>`__
|
||||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
||||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
||||||
config.max_position_embeddings - 1]``.
|
config.max_position_embeddings - 1]``.
|
||||||
|
|
||||||
`What are position IDs? <../glossary.html#position-ids>`_
|
|
||||||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
|
||||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
|
||||||
- 0 indicates the head is **masked**.
|
|
||||||
|
|
||||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
|
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
|
||||||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
|
||||||
vectors than the model's internal embedding lookup matrix.
|
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
|
||||||
tensors for more detail.
|
|
||||||
output_hidden_states (:obj:`bool`, `optional`):
|
|
||||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
|
||||||
more detail.
|
|
||||||
return_dict (:obj:`bool`, `optional`):
|
return_dict (:obj:`bool`, `optional`):
|
||||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
@@ -291,7 +280,7 @@ class FlaxBertModule(nn.Module):
|
|||||||
intermediate_size: int
|
intermediate_size: int
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
|
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids):
|
||||||
|
|
||||||
# Embedding
|
# Embedding
|
||||||
embeddings = FlaxBertEmbeddings(
|
embeddings = FlaxBertEmbeddings(
|
||||||
@@ -410,7 +399,8 @@ class FlaxBertModel(FlaxPreTrainedModel):
|
|||||||
def module(self) -> nn.Module:
|
def module(self) -> nn.Module:
|
||||||
return self._module
|
return self._module
|
||||||
|
|
||||||
def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
|
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
|
def __call__(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None):
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = jnp.ones_like(input_ids)
|
token_type_ids = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
@@ -423,7 +413,7 @@ class FlaxBertModel(FlaxPreTrainedModel):
|
|||||||
return self.model.apply(
|
return self.model.apply(
|
||||||
{"params": self.params},
|
{"params": self.params},
|
||||||
jnp.array(input_ids, dtype="i4"),
|
jnp.array(input_ids, dtype="i4"),
|
||||||
|
jnp.array(attention_mask, dtype="i4"),
|
||||||
jnp.array(token_type_ids, dtype="i4"),
|
jnp.array(token_type_ids, dtype="i4"),
|
||||||
jnp.array(position_ids, dtype="i4"),
|
jnp.array(position_ids, dtype="i4"),
|
||||||
jnp.array(attention_mask, dtype="i4"),
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import jax
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from .configuration_roberta import RobertaConfig
|
from .configuration_roberta import RobertaConfig
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||||
from .modeling_flax_utils import FlaxPreTrainedModel, gelu
|
from .modeling_flax_utils import FlaxPreTrainedModel, gelu
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
@@ -34,13 +34,20 @@ _TOKENIZER_FOR_DOC = "RobertaTokenizer"
|
|||||||
|
|
||||||
ROBERTA_START_DOCSTRING = r"""
|
ROBERTA_START_DOCSTRING = r"""
|
||||||
|
|
||||||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
|
||||||
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
|
generic methods the library implements for all its models (such as downloading, saving and converting weights from
|
||||||
pruning heads etc.)
|
PyTorch models)
|
||||||
|
|
||||||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
|
This model is also a Flax Linen `flax.linen.Module
|
||||||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
|
<https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax
|
||||||
general usage and behavior.
|
Module and refer to the Flax documentation for all matters related to general usage and behavior.
|
||||||
|
|
||||||
|
Finally, this model supports inherent JAX features such as:
|
||||||
|
|
||||||
|
- `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
|
||||||
|
- `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
|
||||||
|
- `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
|
||||||
|
- `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
|
config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
|
||||||
@@ -51,50 +58,32 @@ ROBERTA_START_DOCSTRING = r"""
|
|||||||
|
|
||||||
ROBERTA_INPUTS_DOCSTRING = r"""
|
ROBERTA_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
|
input_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`):
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
Indices can be obtained using :class:`~transformers.RobertaTokenizer`. See
|
Indices can be obtained using :class:`~transformers.RobertaTokenizer`. See
|
||||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||||
details.
|
details.
|
||||||
|
|
||||||
`What are input IDs? <../glossary.html#input-ids>`__
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
|
attention_mask (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
- 1 for tokens that are **not masked**,
|
- 1 for tokens that are **not masked**,
|
||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
token_type_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
||||||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
||||||
1]``:
|
1]``:
|
||||||
|
|
||||||
- 0 corresponds to a `sentence A` token,
|
- 0 corresponds to a `sentence A` token,
|
||||||
- 1 corresponds to a `sentence B` token.
|
- 1 corresponds to a `sentence B` token.
|
||||||
|
|
||||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
`What are token type IDs? <../glossary.html#token-type-ids>`__
|
||||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
position_ids (:obj:`numpy.ndarray` of shape :obj:`({0})`, `optional`):
|
||||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
||||||
config.max_position_embeddings - 1]``.
|
config.max_position_embeddings - 1]``.
|
||||||
|
|
||||||
`What are position IDs? <../glossary.html#position-ids>`_
|
|
||||||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
|
||||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
|
||||||
- 0 indicates the head is **masked**.
|
|
||||||
|
|
||||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
|
|
||||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
|
||||||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
|
||||||
vectors than the model's internal embedding lookup matrix.
|
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
|
||||||
tensors for more detail.
|
|
||||||
output_hidden_states (:obj:`bool`, `optional`):
|
|
||||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
|
||||||
more detail.
|
|
||||||
return_dict (:obj:`bool`, `optional`):
|
return_dict (:obj:`bool`, `optional`):
|
||||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
@@ -302,7 +291,7 @@ class FlaxRobertaModule(nn.Module):
|
|||||||
intermediate_size: int
|
intermediate_size: int
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
|
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids):
|
||||||
|
|
||||||
# Embedding
|
# Embedding
|
||||||
embeddings = FlaxRobertaEmbeddings(
|
embeddings = FlaxRobertaEmbeddings(
|
||||||
@@ -421,7 +410,8 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
|
|||||||
def module(self) -> nn.Module:
|
def module(self) -> nn.Module:
|
||||||
return self._module
|
return self._module
|
||||||
|
|
||||||
def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
|
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
|
def __call__(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None):
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = jnp.ones_like(input_ids)
|
token_type_ids = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
@@ -436,7 +426,7 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
|
|||||||
return self.model.apply(
|
return self.model.apply(
|
||||||
{"params": self.params},
|
{"params": self.params},
|
||||||
jnp.array(input_ids, dtype="i4"),
|
jnp.array(input_ids, dtype="i4"),
|
||||||
|
jnp.array(attention_mask, dtype="i4"),
|
||||||
jnp.array(token_type_ids, dtype="i4"),
|
jnp.array(token_type_ids, dtype="i4"),
|
||||||
jnp.array(position_ids, dtype="i4"),
|
jnp.array(position_ids, dtype="i4"),
|
||||||
jnp.array(attention_mask, dtype="i4"),
|
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user