[Flax] Correct flax docs (#12782)
* fix_torch_device_generate_test * remove @ * fix flax docs * correct more docs in flax * another correction * fix flax docs * Apply suggestions from code review
This commit is contained in:
committed by
GitHub
parent
a317e6c3be
commit
fbf468b057
@@ -299,3 +299,93 @@ TFSeq2SeqQuestionAnsweringModelOutput
|
|||||||
|
|
||||||
.. autoclass:: transformers.modeling_tf_outputs.TFSeq2SeqQuestionAnsweringModelOutput
|
.. autoclass:: transformers.modeling_tf_outputs.TFSeq2SeqQuestionAnsweringModelOutput
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
FlaxBaseModelOutput
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutput
|
||||||
|
|
||||||
|
|
||||||
|
FlaxBaseModelOutputWithPast
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPast
|
||||||
|
|
||||||
|
|
||||||
|
FlaxBaseModelOutputWithPooling
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPooling
|
||||||
|
|
||||||
|
|
||||||
|
FlaxBaseModelOutputWithPastAndCrossAttentions
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions
|
||||||
|
|
||||||
|
|
||||||
|
FlaxSeq2SeqModelOutput
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqModelOutput
|
||||||
|
|
||||||
|
|
||||||
|
FlaxCausalLMOutputWithCrossAttentions
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_flax_outputs.FlaxCausalLMOutputWithCrossAttentions
|
||||||
|
|
||||||
|
|
||||||
|
FlaxMaskedLMOutput
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_flax_outputs.FlaxMaskedLMOutput
|
||||||
|
|
||||||
|
|
||||||
|
FlaxSeq2SeqLMOutput
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqLMOutput
|
||||||
|
|
||||||
|
|
||||||
|
FlaxNextSentencePredictorOutput
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_flax_outputs.FlaxNextSentencePredictorOutput
|
||||||
|
|
||||||
|
|
||||||
|
FlaxSequenceClassifierOutput
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_flax_outputs.FlaxSequenceClassifierOutput
|
||||||
|
|
||||||
|
|
||||||
|
FlaxSeq2SeqSequenceClassifierOutput
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqSequenceClassifierOutput
|
||||||
|
|
||||||
|
|
||||||
|
FlaxMultipleChoiceModelOutput
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_flax_outputs.FlaxMultipleChoiceModelOutput
|
||||||
|
|
||||||
|
|
||||||
|
FlaxTokenClassifierOutput
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_flax_outputs.FlaxTokenClassifierOutput
|
||||||
|
|
||||||
|
|
||||||
|
FlaxQuestionAnsweringModelOutput
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_flax_outputs.FlaxQuestionAnsweringModelOutput
|
||||||
|
|
||||||
|
|
||||||
|
FlaxSeq2SeqQuestionAnsweringModelOutput
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqQuestionAnsweringModelOutput
|
||||||
|
|||||||
@@ -76,6 +76,9 @@ Bert specific outputs
|
|||||||
.. autoclass:: transformers.models.bert.modeling_tf_bert.TFBertForPreTrainingOutput
|
.. autoclass:: transformers.models.bert.modeling_tf_bert.TFBertForPreTrainingOutput
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
.. autoclass:: transformers.models.bert.modeling_flax_bert.FlaxBertForPreTrainingOutput
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
BertModel
|
BertModel
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|||||||
@@ -67,6 +67,22 @@ Wav2Vec2Processor
|
|||||||
:members: __call__, pad, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor
|
:members: __call__, pad, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor
|
||||||
|
|
||||||
|
|
||||||
|
Wav2Vec2 specific outputs
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2BaseModelOutput
|
||||||
|
:members:
|
||||||
|
|
||||||
|
.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput
|
||||||
|
:members:
|
||||||
|
|
||||||
|
.. autoclass:: transformers.models.wav2vec2.modeling_flax_wav2vec2.FlaxWav2Vec2BaseModelOutput
|
||||||
|
:members:
|
||||||
|
|
||||||
|
.. autoclass:: transformers.models.wav2vec2.modeling_flax_wav2vec2.FlaxWav2Vec2ForPreTrainingOutput
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
Wav2Vec2Model
|
Wav2Vec2Model
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -1063,7 +1063,7 @@ FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
|
|||||||
|
|
||||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax')
|
||||||
|
|
||||||
>>> outputs = model(**inputs, labels=labels)
|
>>> outputs = model(**inputs)
|
||||||
>>> logits = outputs.logits
|
>>> logits = outputs.logits
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1122,9 +1122,10 @@ FLAX_CAUSAL_LM_SAMPLE = r"""
|
|||||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||||
|
|
||||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
|
||||||
>>> outputs = model(**inputs, labels=inputs["input_ids"])
|
>>> outputs = model(**inputs)
|
||||||
|
|
||||||
>>> logits = outputs.logits
|
>>> # retrieve logts for next token
|
||||||
|
>>> next_token_logits = outputs.logits[:, -1]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
FLAX_SAMPLE_DOCSTRINGS = {
|
FLAX_SAMPLE_DOCSTRINGS = {
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from flax.linen.attention import dot_product_attention_weights
|
|||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings, replace_return_docstrings
|
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||||
from ...modeling_flax_outputs import (
|
from ...modeling_flax_outputs import (
|
||||||
FlaxBaseModelOutput,
|
FlaxBaseModelOutput,
|
||||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
||||||
@@ -1167,6 +1167,7 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
input_ids: jnp.ndarray,
|
input_ids: jnp.ndarray,
|
||||||
@@ -1520,7 +1521,7 @@ FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING = """
|
|||||||
>>> input_ids = tokenizer([TXT], return_tensors='jax')['input_ids']
|
>>> input_ids = tokenizer([TXT], return_tensors='jax')['input_ids']
|
||||||
>>> logits = model(input_ids).logits
|
>>> logits = model(input_ids).logits
|
||||||
|
|
||||||
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
|
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
|
||||||
>>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
|
>>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
|
||||||
>>> values, predictions = jax.lax.top_k(probs)
|
>>> values, predictions = jax.lax.top_k(probs)
|
||||||
|
|
||||||
|
|||||||
@@ -941,7 +941,7 @@ FLAX_CLIP_TEXT_MODEL_DOCSTRING = """
|
|||||||
|
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
>>> last_hidden_state = outputs.last_hidden_state
|
>>> last_hidden_state = outputs.last_hidden_state
|
||||||
>>> pooled_output = outputs.pooled_output # pooled (EOS token) states
|
>>> pooler_output = outputs.pooler_output # pooled (EOS token) states
|
||||||
"""
|
"""
|
||||||
|
|
||||||
overwrite_call_docstring(FlaxCLIPTextModel, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_DOCSTRING)
|
overwrite_call_docstring(FlaxCLIPTextModel, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_DOCSTRING)
|
||||||
@@ -997,7 +997,7 @@ FLAX_CLIP_VISION_MODEL_DOCSTRING = """
|
|||||||
|
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
>>> last_hidden_state = outputs.last_hidden_state
|
>>> last_hidden_state = outputs.last_hidden_state
|
||||||
>>> pooled_output = outputs.pooled_output # pooled CLS states
|
>>> pooler_output = outputs.pooler_output # pooled CLS states
|
||||||
"""
|
"""
|
||||||
|
|
||||||
overwrite_call_docstring(FlaxCLIPVisionModel, CLIP_VISION_INPUTS_DOCSTRING + FLAX_CLIP_VISION_MODEL_DOCSTRING)
|
overwrite_call_docstring(FlaxCLIPVisionModel, CLIP_VISION_INPUTS_DOCSTRING + FLAX_CLIP_VISION_MODEL_DOCSTRING)
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from flax.linen.attention import dot_product_attention_weights
|
|||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings, replace_return_docstrings
|
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||||
from ...modeling_flax_outputs import (
|
from ...modeling_flax_outputs import (
|
||||||
FlaxBaseModelOutput,
|
FlaxBaseModelOutput,
|
||||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
||||||
@@ -45,7 +45,7 @@ from .configuration_marian import MarianConfig
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de'"
|
_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de"
|
||||||
_CONFIG_FOR_DOC = "MarianConfig"
|
_CONFIG_FOR_DOC = "MarianConfig"
|
||||||
_TOKENIZER_FOR_DOC = "MarianTokenizer"
|
_TOKENIZER_FOR_DOC = "MarianTokenizer"
|
||||||
|
|
||||||
@@ -1125,6 +1125,7 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING)
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
input_ids: jnp.ndarray,
|
input_ids: jnp.ndarray,
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from flax.linen.attention import dot_product_attention_weights
|
|||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings, replace_return_docstrings
|
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||||
from ...modeling_flax_outputs import (
|
from ...modeling_flax_outputs import (
|
||||||
FlaxBaseModelOutput,
|
FlaxBaseModelOutput,
|
||||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
||||||
@@ -1192,6 +1192,7 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
input_ids: jnp.ndarray,
|
input_ids: jnp.ndarray,
|
||||||
@@ -1517,36 +1518,37 @@ class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel):
|
|||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
|
|
||||||
FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING = """
|
FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING = r"""
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
Summarization example::
|
Summarization example::
|
||||||
|
|
||||||
>>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration
|
>>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration, MBartConfig
|
||||||
|
|
||||||
>>> model = FlaxMBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25')
|
>>> model = FlaxMBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25')
|
||||||
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25')
|
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25')
|
||||||
|
|
||||||
>>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
|
>>> ARTICLE_TO_SUMMARIZE = "Meine Freunde sind cool, aber sie essen zu viel Kuchen."
|
||||||
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='jax')
|
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np')
|
||||||
|
|
||||||
>>> # Generate Summary
|
>>> # Generate Summary
|
||||||
>>> summary_ids = model.generate(inputs['input_ids'], decoder_start_token_id=tokenizer.lang_code_to_id[tgt_lang]).sequences
|
>>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True).sequences
|
||||||
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
|
>>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
|
||||||
|
|
||||||
Mask filling example::
|
Mask filling example::
|
||||||
|
|
||||||
>>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration
|
>>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration
|
||||||
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25')
|
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25')
|
||||||
>>> TXT = "My friends are <mask> but they eat too many carbs."
|
>>> # de_DE is the language symbol id <LID> for German
|
||||||
|
>>> TXT = "</s> Meine Freunde sind <mask> nett aber sie essen zu viel Kuchen. </s> de_DE"
|
||||||
|
|
||||||
>>> model = FlaxMBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25')
|
>>> model = FlaxMBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25')
|
||||||
>>> input_ids = tokenizer([TXT], return_tensors='jax')['input_ids']
|
>>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors='np')['input_ids']
|
||||||
>>> logits = model(input_ids).logits
|
>>> logits = model(input_ids).logits
|
||||||
|
|
||||||
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
|
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
|
||||||
>>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
|
>>> probs = logits[0, masked_index].softmax(dim=0)
|
||||||
>>> values, predictions = jax.lax.top_k(probs)
|
>>> values, predictions = probs.topk(5)
|
||||||
|
|
||||||
>>> tokenizer.decode(predictions).split()
|
>>> tokenizer.decode(predictions).split()
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -36,13 +36,20 @@ from ...modeling_flax_outputs import (
|
|||||||
FlaxSeq2SeqLMOutput,
|
FlaxSeq2SeqLMOutput,
|
||||||
FlaxSeq2SeqModelOutput,
|
FlaxSeq2SeqModelOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
|
from ...modeling_flax_utils import (
|
||||||
|
ACT2FN,
|
||||||
|
FlaxPreTrainedModel,
|
||||||
|
append_call_sample_docstring,
|
||||||
|
append_replace_return_docstrings,
|
||||||
|
overwrite_call_docstring,
|
||||||
|
)
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_t5 import T5Config
|
from .configuration_t5 import T5Config
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
_CHECKPOINT_FOR_DOC = "t5-small"
|
||||||
_CONFIG_FOR_DOC = "T5Config"
|
_CONFIG_FOR_DOC = "T5Config"
|
||||||
_TOKENIZER_FOR_DOC = "T5Tokenizer"
|
_TOKENIZER_FOR_DOC = "T5Tokenizer"
|
||||||
|
|
||||||
@@ -844,6 +851,69 @@ T5_DECODE_INPUTS_DOCSTRING = r"""
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
T5_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
|
||||||
|
should be able to pad the inputs on both the right and the left.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||||
|
detail.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
|
||||||
|
To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training
|
||||||
|
<./t5.html#training>`__.
|
||||||
|
attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||||
|
Indices of decoder input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||||
|
details.
|
||||||
|
|
||||||
|
`What are decoder input IDs? <../glossary.html#decoder-input-ids>`__
|
||||||
|
|
||||||
|
T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If
|
||||||
|
:obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
|
||||||
|
:obj:`past_key_values`).
|
||||||
|
|
||||||
|
To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training
|
||||||
|
<./t5.html#training>`__.
|
||||||
|
decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
||||||
|
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
|
||||||
|
also be used by default.
|
||||||
|
encoder_outputs (:obj:`tuple(tuple(jnp.ndarray)`, `optional`):
|
||||||
|
Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
|
||||||
|
`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
|
||||||
|
sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of
|
||||||
|
the decoder.
|
||||||
|
past_key_values (:obj:`tuple(tuple(jnp.ndarray))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
|
||||||
|
|
||||||
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
|
class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
|
||||||
"""
|
"""
|
||||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
@@ -884,6 +954,7 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
|
|||||||
decoder_attention_mask,
|
decoder_attention_mask,
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
input_ids: jnp.ndarray,
|
input_ids: jnp.ndarray,
|
||||||
@@ -1155,71 +1226,6 @@ T5_START_DOCSTRING = r"""
|
|||||||
model weights.
|
model weights.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
T5_INPUTS_DOCSTRING = r"""
|
|
||||||
Args:
|
|
||||||
input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
|
|
||||||
Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
|
|
||||||
should be able to pad the inputs on both the right and the left.
|
|
||||||
|
|
||||||
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
|
|
||||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
|
||||||
detail.
|
|
||||||
|
|
||||||
`What are input IDs? <../glossary.html#input-ids>`__
|
|
||||||
|
|
||||||
To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training
|
|
||||||
<./t5.html#training>`__.
|
|
||||||
attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
|
||||||
|
|
||||||
- 1 for tokens that are **not masked**,
|
|
||||||
- 0 for tokens that are **masked**.
|
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
|
||||||
decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
|
||||||
Indices of decoder input sequence tokens in the vocabulary.
|
|
||||||
|
|
||||||
Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
|
|
||||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
|
||||||
details.
|
|
||||||
|
|
||||||
`What are decoder input IDs? <../glossary.html#decoder-input-ids>`__
|
|
||||||
|
|
||||||
T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If
|
|
||||||
:obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
|
|
||||||
:obj:`past_key_values`).
|
|
||||||
|
|
||||||
To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training
|
|
||||||
<./t5.html#training>`__.
|
|
||||||
decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
|
|
||||||
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
|
|
||||||
also be used by default.
|
|
||||||
encoder_outputs (:obj:`tuple(tuple(jnp.ndarray)`, `optional`):
|
|
||||||
Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
|
|
||||||
`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
|
|
||||||
sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of
|
|
||||||
the decoder.
|
|
||||||
past_key_values (:obj:`tuple(tuple(jnp.ndarray))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
|
||||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
|
||||||
|
|
||||||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
|
||||||
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
|
||||||
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
|
||||||
|
|
||||||
use_cache (:obj:`bool`, `optional`):
|
|
||||||
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
|
||||||
decoding (see :obj:`past_key_values`).
|
|
||||||
|
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
|
||||||
tensors for more detail.
|
|
||||||
output_hidden_states (:obj:`bool`, `optional`):
|
|
||||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
|
||||||
more detail.
|
|
||||||
return_dict (:obj:`bool`, `optional`):
|
|
||||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
|
"The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.",
|
||||||
@@ -1252,8 +1258,6 @@ class FlaxT5Module(nn.Module):
|
|||||||
decoder_config.num_layers = self.config.num_decoder_layers
|
decoder_config.num_layers = self.config.num_decoder_layers
|
||||||
self.decoder = FlaxT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
|
self.decoder = FlaxT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(output_type=FlaxSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
@@ -1266,22 +1270,6 @@ class FlaxT5Module(nn.Module):
|
|||||||
return_dict=None,
|
return_dict=None,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
):
|
):
|
||||||
r"""
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
>>> from transformers import T5Tokenizer, FlaxT5Model
|
|
||||||
|
|
||||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
|
||||||
>>> model = FlaxT5Model.from_pretrained('t5-small')
|
|
||||||
|
|
||||||
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="np").input_ids # Batch size 1
|
|
||||||
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids # Batch size 1
|
|
||||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
|
||||||
|
|
||||||
>>> last_hidden_states = outputs.last_hidden_state
|
|
||||||
"""
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
# Encode if needed (training, first prediction pass)
|
# Encode if needed (training, first prediction pass)
|
||||||
@@ -1325,6 +1313,32 @@ class FlaxT5Model(FlaxT5PreTrainedModel):
|
|||||||
module_class = FlaxT5Module
|
module_class = FlaxT5Module
|
||||||
|
|
||||||
|
|
||||||
|
append_call_sample_docstring(
|
||||||
|
FlaxT5Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
|
||||||
|
FLAX_T5_MODEL_DOCSTRING = """
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import T5Tokenizer, FlaxT5Model
|
||||||
|
|
||||||
|
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||||
|
>>> model = FlaxT5Model.from_pretrained('t5-small')
|
||||||
|
|
||||||
|
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="np").input_ids
|
||||||
|
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
|
||||||
|
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||||
|
|
||||||
|
>>> last_hidden_states = outputs.last_hidden_state
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTRING)
|
||||||
|
append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
|
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
|
||||||
class FlaxT5ForConditionalGenerationModule(nn.Module):
|
class FlaxT5ForConditionalGenerationModule(nn.Module):
|
||||||
config: T5Config
|
config: T5Config
|
||||||
@@ -1364,8 +1378,6 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
|
|||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
|
|
||||||
@replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
@@ -1378,24 +1390,6 @@ class FlaxT5ForConditionalGenerationModule(nn.Module):
|
|||||||
return_dict=None,
|
return_dict=None,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
):
|
):
|
||||||
r"""
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Examples::
|
|
||||||
|
|
||||||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
|
||||||
|
|
||||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
|
||||||
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
|
|
||||||
|
|
||||||
>>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='np').input_ids
|
|
||||||
>>> decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='np').input_ids
|
|
||||||
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
|
||||||
>>> logits = outputs.logits
|
|
||||||
|
|
||||||
>>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="np").input_ids
|
|
||||||
>>> outputs = model.generate(input_ids)
|
|
||||||
"""
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
# Encode
|
# Encode
|
||||||
@@ -1479,7 +1473,7 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
|
|||||||
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
|
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
|
||||||
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||||
|
|
||||||
>>> text = "My friends are cool but they eat too many carbs."
|
>>> text = "summarize: My friends are cool but they eat too many carbs."
|
||||||
>>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
|
>>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
|
||||||
>>> encoder_outputs = model.encode(**inputs)
|
>>> encoder_outputs = model.encode(**inputs)
|
||||||
|
|
||||||
@@ -1614,3 +1608,30 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
|
|||||||
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
||||||
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING = """
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
|
||||||
|
|
||||||
|
>>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')
|
||||||
|
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
|
||||||
|
|
||||||
|
>>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs."
|
||||||
|
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=512, return_tensors='jax')
|
||||||
|
|
||||||
|
>>> # Generate Summary
|
||||||
|
>>> summary_ids = model.generate(inputs['input_ids']).sequences
|
||||||
|
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
overwrite_call_docstring(
|
||||||
|
FlaxT5ForConditionalGeneration, T5_INPUTS_DOCSTRING + FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING
|
||||||
|
)
|
||||||
|
append_replace_return_docstrings(
|
||||||
|
FlaxT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
|||||||
@@ -581,7 +581,7 @@ FLAX_VISION_CLASSIF_DOCSTRING = """
|
|||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
>>> from transformers import FlaxViTFeatureExtractor, ViTForImageClassification
|
>>> from transformers import ViTFeatureExtractor, FlaxViTForImageClassification
|
||||||
>>> from PIL import Image
|
>>> from PIL import Image
|
||||||
>>> import jax
|
>>> import jax
|
||||||
>>> import requests
|
>>> import requests
|
||||||
@@ -595,9 +595,10 @@ FLAX_VISION_CLASSIF_DOCSTRING = """
|
|||||||
>>> inputs = feature_extractor(images=image, return_tensors="jax")
|
>>> inputs = feature_extractor(images=image, return_tensors="jax")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
>>> logits = outputs.logits
|
>>> logits = outputs.logits
|
||||||
|
|
||||||
>>> # model predicts one of the 1000 ImageNet classes
|
>>> # model predicts one of the 1000 ImageNet classes
|
||||||
>>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
|
>>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
|
||||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
>>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
overwrite_call_docstring(FlaxViTForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)
|
overwrite_call_docstring(FlaxViTForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)
|
||||||
|
|||||||
@@ -29,7 +29,12 @@ from jax import lax
|
|||||||
|
|
||||||
from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
|
from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
||||||
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
|
from ...modeling_flax_utils import (
|
||||||
|
ACT2FN,
|
||||||
|
FlaxPreTrainedModel,
|
||||||
|
append_replace_return_docstrings,
|
||||||
|
overwrite_call_docstring,
|
||||||
|
)
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_wav2vec2 import Wav2Vec2Config
|
from .configuration_wav2vec2 import Wav2Vec2Config
|
||||||
|
|
||||||
@@ -853,31 +858,6 @@ class FlaxWav2Vec2Module(nn.Module):
|
|||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
>>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2Model
|
|
||||||
>>> from datasets import load_dataset
|
|
||||||
>>> import soundfile as sf
|
|
||||||
|
|
||||||
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
|
|
||||||
>>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
|
||||||
|
|
||||||
>>> def map_to_array(batch):
|
|
||||||
>>> speech, _ = sf.read(batch["file"])
|
|
||||||
>>> batch["speech"] = speech
|
|
||||||
>>> return batch
|
|
||||||
|
|
||||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
|
||||||
>>> ds = ds.map(map_to_array)
|
|
||||||
|
|
||||||
>>> input_values = processor(ds["speech"][0], return_tensors="np").input_values # Batch size 1
|
|
||||||
>>> hidden_states = model(input_values).last_hidden_state
|
|
||||||
|
|
||||||
"""
|
|
||||||
extract_features = self.feature_extractor(input_values)
|
extract_features = self.feature_extractor(input_values)
|
||||||
|
|
||||||
# make sure that no loss is computed on padded inputs
|
# make sure that no loss is computed on padded inputs
|
||||||
@@ -947,6 +927,39 @@ class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
|
|||||||
module_class = FlaxWav2Vec2Module
|
module_class = FlaxWav2Vec2Module
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_WAV2VEC2_MODEL_DOCSTRING = """
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2Model
|
||||||
|
>>> from datasets import load_dataset
|
||||||
|
>>> import soundfile as sf
|
||||||
|
|
||||||
|
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-lv60")
|
||||||
|
>>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60")
|
||||||
|
|
||||||
|
>>> def map_to_array(batch):
|
||||||
|
>>> speech, _ = sf.read(batch["file"])
|
||||||
|
>>> batch["speech"] = speech
|
||||||
|
>>> return batch
|
||||||
|
|
||||||
|
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||||
|
>>> ds = ds.map(map_to_array)
|
||||||
|
|
||||||
|
>>> input_values = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="np").input_values # Batch size 1
|
||||||
|
>>> hidden_states = model(input_values).last_hidden_state
|
||||||
|
"""
|
||||||
|
|
||||||
|
overwrite_call_docstring(
|
||||||
|
FlaxWav2Vec2Model,
|
||||||
|
WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_MODEL_DOCSTRING,
|
||||||
|
)
|
||||||
|
append_replace_return_docstrings(
|
||||||
|
FlaxWav2Vec2Model, output_type=FlaxWav2Vec2BaseModelOutput, config_class=Wav2Vec2Config
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlaxWav2Vec2ForCTCModule(nn.Module):
|
class FlaxWav2Vec2ForCTCModule(nn.Module):
|
||||||
config: Wav2Vec2Config
|
config: Wav2Vec2Config
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
@@ -970,36 +983,6 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
|
|||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
):
|
):
|
||||||
r"""
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
>>> import jax.numpy as jnp
|
|
||||||
>>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2ForCTC
|
|
||||||
>>> from datasets import load_dataset
|
|
||||||
>>> import soundfile as sf
|
|
||||||
|
|
||||||
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
|
|
||||||
>>> model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
|
||||||
|
|
||||||
>>> def map_to_array(batch):
|
|
||||||
>>> speech, _ = sf.read(batch["file"])
|
|
||||||
>>> batch["speech"] = speech
|
|
||||||
>>> return batch
|
|
||||||
|
|
||||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
|
||||||
>>> ds = ds.map(map_to_array)
|
|
||||||
|
|
||||||
>>> input_values = processor(ds["speech"][0], return_tensors="np").input_values # Batch size 1
|
|
||||||
>>> logits = model(input_values).logits
|
|
||||||
>>> predicted_ids = jnp.argmax(logits, axis=-1)
|
|
||||||
|
|
||||||
>>> transcription = processor.decode(predicted_ids[0])
|
|
||||||
>>> # should give: "A MAN SAID TO THE UNIVERSE SIR I EXIST"
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
outputs = self.wav2vec2(
|
outputs = self.wav2vec2(
|
||||||
input_values,
|
input_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
@@ -1044,6 +1027,46 @@ class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
|
|||||||
module_class = FlaxWav2Vec2ForCTCModule
|
module_class = FlaxWav2Vec2ForCTCModule
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_WAV2VEC2_FOR_CTC_DOCSTRING = """
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> import jax.numpy as jnp
|
||||||
|
>>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2ForCTC
|
||||||
|
>>> from datasets import load_dataset
|
||||||
|
>>> import soundfile as sf
|
||||||
|
|
||||||
|
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60")
|
||||||
|
>>> model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60")
|
||||||
|
|
||||||
|
>>> def map_to_array(batch):
|
||||||
|
>>> speech, _ = sf.read(batch["file"])
|
||||||
|
>>> batch["speech"] = speech
|
||||||
|
>>> return batch
|
||||||
|
|
||||||
|
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||||
|
>>> ds = ds.map(map_to_array)
|
||||||
|
|
||||||
|
>>> input_values = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="np").input_values # Batch size 1
|
||||||
|
>>> logits = model(input_values).logits
|
||||||
|
>>> predicted_ids = jnp.argmax(logits, axis=-1)
|
||||||
|
|
||||||
|
>>> transcription = processor.decode(predicted_ids[0])
|
||||||
|
>>> # should give: "A MAN SAID TO THE UNIVERSE SIR I EXIST"
|
||||||
|
"""
|
||||||
|
|
||||||
|
overwrite_call_docstring(
|
||||||
|
FlaxWav2Vec2ForCTC,
|
||||||
|
WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_CTC_DOCSTRING,
|
||||||
|
)
|
||||||
|
append_replace_return_docstrings(FlaxWav2Vec2ForCTC, output_type=FlaxCausalLMOutput, config_class=Wav2Vec2Config)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxWav2Vec2ForCTCModule(nn.Module):
|
||||||
|
config: Wav2Vec2Config
|
||||||
|
|
||||||
|
|
||||||
class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
|
class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
|
||||||
config: Wav2Vec2Config
|
config: Wav2Vec2Config
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
@@ -1080,43 +1103,6 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
|
|||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
>>> import optax
|
|
||||||
>>> import numpy as np
|
|
||||||
>>> import jax.numpy as jnp
|
|
||||||
>>> from transformers import Wav2Vec2FeatureExtractor, FlaxWav2Vec2ForPreTraining
|
|
||||||
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
|
|
||||||
>>> from datasets import load_dataset
|
|
||||||
>>> import soundfile as sf
|
|
||||||
|
|
||||||
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base")
|
|
||||||
>>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
|
|
||||||
|
|
||||||
|
|
||||||
>>> def map_to_array(batch):
|
|
||||||
... speech, _ = sf.read(batch["file"])
|
|
||||||
... batch["speech"] = speech
|
|
||||||
... return batch
|
|
||||||
|
|
||||||
|
|
||||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
|
||||||
>>> ds = ds.map(map_to_array)
|
|
||||||
|
|
||||||
>>> input_values = feature_extractor(ds["speech"][0], return_tensors="np").input_values # Batch size 1
|
|
||||||
|
|
||||||
>>> # compute masked indices
|
|
||||||
>>> batch_size, raw_sequence_length = input_values.shape
|
|
||||||
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
|
|
||||||
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
|
|
||||||
|
|
||||||
>>> outputs = model(input_values, mask_time_indices=mask_time_indices)
|
|
||||||
|
|
||||||
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
|
|
||||||
>>> cosine_sim = optax.cosine_similarity(
|
|
||||||
... outputs.projected_states, outputs.projected_quantized_states, axis=-1
|
|
||||||
... )
|
|
||||||
|
|
||||||
>>> # show that cosine similarity is much higher than random
|
|
||||||
>>> assert np.asarray(cosine_sim)[mask_time_indices].mean() > 0.5
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1222,3 +1208,60 @@ class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel):
|
|||||||
return_dict,
|
return_dict,
|
||||||
rngs=rngs,
|
rngs=rngs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING = """
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> import optax
|
||||||
|
>>> import numpy as np
|
||||||
|
>>> import jax.numpy as jnp
|
||||||
|
>>> from transformers import Wav2Vec2FeatureExtractor, FlaxWav2Vec2ForPreTraining
|
||||||
|
>>> from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices
|
||||||
|
>>> from datasets import load_dataset
|
||||||
|
>>> import soundfile as sf
|
||||||
|
|
||||||
|
>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-lv60")
|
||||||
|
>>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60")
|
||||||
|
|
||||||
|
|
||||||
|
>>> def map_to_array(batch):
|
||||||
|
... speech, _ = sf.read(batch["file"])
|
||||||
|
... batch["speech"] = speech
|
||||||
|
... return batch
|
||||||
|
|
||||||
|
|
||||||
|
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||||
|
>>> ds = ds.map(map_to_array)
|
||||||
|
|
||||||
|
>>> input_values = feature_extractor(ds["speech"][0], return_tensors="np").input_values # Batch size 1
|
||||||
|
|
||||||
|
>>> # compute masked indices
|
||||||
|
>>> batch_size, raw_sequence_length = input_values.shape
|
||||||
|
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
|
||||||
|
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
|
||||||
|
|
||||||
|
>>> outputs = model(input_values, mask_time_indices=mask_time_indices)
|
||||||
|
|
||||||
|
>>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
|
||||||
|
>>> cosine_sim = optax.cosine_similarity(
|
||||||
|
... outputs.projected_states, outputs.projected_quantized_states
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> # show that cosine similarity is much higher than random
|
||||||
|
>>> assert np.asarray(cosine_sim)[mask_time_indices].mean() > 0.5
|
||||||
|
"""
|
||||||
|
|
||||||
|
overwrite_call_docstring(
|
||||||
|
FlaxWav2Vec2ForPreTraining,
|
||||||
|
WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING,
|
||||||
|
)
|
||||||
|
append_replace_return_docstrings(
|
||||||
|
FlaxWav2Vec2ForPreTraining, output_type=FlaxWav2Vec2ForPreTrainingOutput, config_class=Wav2Vec2Config
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxWav2Vec2ForCTCModule(nn.Module):
|
||||||
|
config: Wav2Vec2Config
|
||||||
|
|||||||
@@ -1183,7 +1183,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
|||||||
return logits
|
return logits
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=Wav2Vec2ForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_values,
|
input_values,
|
||||||
@@ -1338,7 +1338,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
|||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=Wav2Vec2BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_values,
|
input_values,
|
||||||
@@ -1420,7 +1420,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
|||||||
self.wav2vec2.feature_extractor._freeze_parameters()
|
self.wav2vec2.feature_extractor._freeze_parameters()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_values,
|
input_values,
|
||||||
|
|||||||
Reference in New Issue
Block a user