From fbf468b0573baddb1b9d1bb088a8b6d5c9303a7e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 4 Aug 2021 16:31:23 +0200 Subject: [PATCH] [Flax] Correct flax docs (#12782) * fix_torch_device_generate_test * remove @ * fix flax docs * correct more docs in flax * another correction * fix flax docs * Apply suggestions from code review --- docs/source/main_classes/output.rst | 90 +++++++ docs/source/model_doc/bert.rst | 3 + docs/source/model_doc/wav2vec2.rst | 16 ++ src/transformers/file_utils.py | 7 +- .../models/bart/modeling_flax_bart.py | 5 +- .../models/clip/modeling_flax_clip.py | 4 +- .../models/marian/modeling_flax_marian.py | 5 +- .../models/mbart/modeling_flax_mbart.py | 26 +- .../models/t5/modeling_flax_t5.py | 231 ++++++++++-------- .../models/vit/modeling_flax_vit.py | 5 +- .../models/wav2vec2/modeling_flax_wav2vec2.py | 229 ++++++++++------- .../models/wav2vec2/modeling_wav2vec2.py | 6 +- 12 files changed, 403 insertions(+), 224 deletions(-) diff --git a/docs/source/main_classes/output.rst b/docs/source/main_classes/output.rst index a627571f24..5d0bdc7bc6 100644 --- a/docs/source/main_classes/output.rst +++ b/docs/source/main_classes/output.rst @@ -299,3 +299,93 @@ TFSeq2SeqQuestionAnsweringModelOutput .. autoclass:: transformers.modeling_tf_outputs.TFSeq2SeqQuestionAnsweringModelOutput :members: + + +FlaxBaseModelOutput +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutput + + +FlaxBaseModelOutputWithPast +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPast + + +FlaxBaseModelOutputWithPooling +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPooling + + +FlaxBaseModelOutputWithPastAndCrossAttentions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPastAndCrossAttentions + + +FlaxSeq2SeqModelOutput +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqModelOutput + + +FlaxCausalLMOutputWithCrossAttentions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_flax_outputs.FlaxCausalLMOutputWithCrossAttentions + + +FlaxMaskedLMOutput +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_flax_outputs.FlaxMaskedLMOutput + + +FlaxSeq2SeqLMOutput +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqLMOutput + + +FlaxNextSentencePredictorOutput +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_flax_outputs.FlaxNextSentencePredictorOutput + + +FlaxSequenceClassifierOutput +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_flax_outputs.FlaxSequenceClassifierOutput + + +FlaxSeq2SeqSequenceClassifierOutput +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqSequenceClassifierOutput + + +FlaxMultipleChoiceModelOutput +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_flax_outputs.FlaxMultipleChoiceModelOutput + + +FlaxTokenClassifierOutput +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_flax_outputs.FlaxTokenClassifierOutput + + +FlaxQuestionAnsweringModelOutput +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_flax_outputs.FlaxQuestionAnsweringModelOutput + + +FlaxSeq2SeqQuestionAnsweringModelOutput +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_flax_outputs.FlaxSeq2SeqQuestionAnsweringModelOutput diff --git a/docs/source/model_doc/bert.rst b/docs/source/model_doc/bert.rst index 497f04638b..4a73599496 100644 --- a/docs/source/model_doc/bert.rst +++ b/docs/source/model_doc/bert.rst @@ -76,6 +76,9 @@ Bert specific outputs .. autoclass:: transformers.models.bert.modeling_tf_bert.TFBertForPreTrainingOutput :members: +.. autoclass:: transformers.models.bert.modeling_flax_bert.FlaxBertForPreTrainingOutput + :members: + BertModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/wav2vec2.rst b/docs/source/model_doc/wav2vec2.rst index df92a06386..e96eb80329 100644 --- a/docs/source/model_doc/wav2vec2.rst +++ b/docs/source/model_doc/wav2vec2.rst @@ -67,6 +67,22 @@ Wav2Vec2Processor :members: __call__, pad, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor +Wav2Vec2 specific outputs +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2BaseModelOutput + :members: + +.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput + :members: + +.. autoclass:: transformers.models.wav2vec2.modeling_flax_wav2vec2.FlaxWav2Vec2BaseModelOutput + :members: + +.. autoclass:: transformers.models.wav2vec2.modeling_flax_wav2vec2.FlaxWav2Vec2ForPreTrainingOutput + :members: + + Wav2Vec2Model ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 2d40c5edd4..1c18fb6070 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -1063,7 +1063,7 @@ FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r""" >>> inputs = tokenizer("Hello, my dog is cute", return_tensors='jax') - >>> outputs = model(**inputs, labels=labels) + >>> outputs = model(**inputs) >>> logits = outputs.logits """ @@ -1122,9 +1122,10 @@ FLAX_CAUSAL_LM_SAMPLE = r""" >>> model = {model_class}.from_pretrained('{checkpoint}') >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") - >>> outputs = model(**inputs, labels=inputs["input_ids"]) + >>> outputs = model(**inputs) - >>> logits = outputs.logits + >>> # retrieve logts for next token + >>> next_token_logits = outputs.logits[:, -1] """ FLAX_SAMPLE_DOCSTRINGS = { diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py index 7a1cceefe4..5bcd3f0952 100644 --- a/src/transformers/models/bart/modeling_flax_bart.py +++ b/src/transformers/models/bart/modeling_flax_bart.py @@ -30,7 +30,7 @@ from flax.linen.attention import dot_product_attention_weights from jax import lax from jax.random import PRNGKey -from ...file_utils import add_start_docstrings, replace_return_docstrings +from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from ...modeling_flax_outputs import ( FlaxBaseModelOutput, FlaxBaseModelOutputWithPastAndCrossAttentions, @@ -1167,6 +1167,7 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel): return outputs + @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) def __call__( self, input_ids: jnp.ndarray, @@ -1520,7 +1521,7 @@ FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING = """ >>> input_ids = tokenizer([TXT], return_tensors='jax')['input_ids'] >>> logits = model(input_ids).logits - >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() + >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item() >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0) >>> values, predictions = jax.lax.top_k(probs) diff --git a/src/transformers/models/clip/modeling_flax_clip.py b/src/transformers/models/clip/modeling_flax_clip.py index e2be39da27..2285bbf1f9 100644 --- a/src/transformers/models/clip/modeling_flax_clip.py +++ b/src/transformers/models/clip/modeling_flax_clip.py @@ -941,7 +941,7 @@ FLAX_CLIP_TEXT_MODEL_DOCSTRING = """ >>> outputs = model(**inputs) >>> last_hidden_state = outputs.last_hidden_state - >>> pooled_output = outputs.pooled_output # pooled (EOS token) states + >>> pooler_output = outputs.pooler_output # pooled (EOS token) states """ overwrite_call_docstring(FlaxCLIPTextModel, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_DOCSTRING) @@ -997,7 +997,7 @@ FLAX_CLIP_VISION_MODEL_DOCSTRING = """ >>> outputs = model(**inputs) >>> last_hidden_state = outputs.last_hidden_state - >>> pooled_output = outputs.pooled_output # pooled CLS states + >>> pooler_output = outputs.pooler_output # pooled CLS states """ overwrite_call_docstring(FlaxCLIPVisionModel, CLIP_VISION_INPUTS_DOCSTRING + FLAX_CLIP_VISION_MODEL_DOCSTRING) diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py index ccbdcab8aa..e8a744c5f4 100644 --- a/src/transformers/models/marian/modeling_flax_marian.py +++ b/src/transformers/models/marian/modeling_flax_marian.py @@ -30,7 +30,7 @@ from flax.linen.attention import dot_product_attention_weights from jax import lax from jax.random import PRNGKey -from ...file_utils import add_start_docstrings, replace_return_docstrings +from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from ...modeling_flax_outputs import ( FlaxBaseModelOutput, FlaxBaseModelOutputWithPastAndCrossAttentions, @@ -45,7 +45,7 @@ from .configuration_marian import MarianConfig logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de'" +_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de" _CONFIG_FOR_DOC = "MarianConfig" _TOKENIZER_FOR_DOC = "MarianTokenizer" @@ -1125,6 +1125,7 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel): return outputs + @add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING) def __call__( self, input_ids: jnp.ndarray, diff --git a/src/transformers/models/mbart/modeling_flax_mbart.py b/src/transformers/models/mbart/modeling_flax_mbart.py index 5673e06afe..a4577d9835 100644 --- a/src/transformers/models/mbart/modeling_flax_mbart.py +++ b/src/transformers/models/mbart/modeling_flax_mbart.py @@ -30,7 +30,7 @@ from flax.linen.attention import dot_product_attention_weights from jax import lax from jax.random import PRNGKey -from ...file_utils import add_start_docstrings, replace_return_docstrings +from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings from ...modeling_flax_outputs import ( FlaxBaseModelOutput, FlaxBaseModelOutputWithPastAndCrossAttentions, @@ -1192,6 +1192,7 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel): return outputs + @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING) def __call__( self, input_ids: jnp.ndarray, @@ -1517,36 +1518,37 @@ class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel): return model_kwargs -FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING = """ +FLAX_MBART_CONDITIONAL_GENERATION_DOCSTRING = r""" Returns: Summarization example:: - >>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration + >>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration, MBartConfig >>> model = FlaxMBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25') >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25') - >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='jax') + >>> ARTICLE_TO_SUMMARIZE = "Meine Freunde sind cool, aber sie essen zu viel Kuchen." + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='np') >>> # Generate Summary - >>> summary_ids = model.generate(inputs['input_ids'], decoder_start_token_id=tokenizer.lang_code_to_id[tgt_lang]).sequences - >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) + >>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True).sequences + >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]) Mask filling example:: >>> from transformers import MBartTokenizer, FlaxMBartForConditionalGeneration >>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-cc25') - >>> TXT = "My friends are but they eat too many carbs." + >>> # de_DE is the language symbol id for German + >>> TXT = " Meine Freunde sind nett aber sie essen zu viel Kuchen. de_DE" >>> model = FlaxMBartForConditionalGeneration.from_pretrained('facebook/mbart-large-cc25') - >>> input_ids = tokenizer([TXT], return_tensors='jax')['input_ids'] + >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors='np')['input_ids'] >>> logits = model(input_ids).logits - >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() - >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0) - >>> values, predictions = jax.lax.top_k(probs) + >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item() + >>> probs = logits[0, masked_index].softmax(dim=0) + >>> values, predictions = probs.topk(5) >>> tokenizer.decode(predictions).split() """ diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index dfaf0976e3..5f493b5767 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -36,13 +36,20 @@ from ...modeling_flax_outputs import ( FlaxSeq2SeqLMOutput, FlaxSeq2SeqModelOutput, ) -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) from ...utils import logging from .configuration_t5 import T5Config logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "t5-small" _CONFIG_FOR_DOC = "T5Config" _TOKENIZER_FOR_DOC = "T5Tokenizer" @@ -844,6 +851,69 @@ T5_DECODE_INPUTS_DOCSTRING = r""" """ +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + detail. + + `What are input IDs? <../glossary.html#input-ids>`__ + + To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training + <./t5.html#training>`__. + attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.T5Tokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are decoder input IDs? <../glossary.html#decoder-input-ids>`__ + + T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If + :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see + :obj:`past_key_values`). + + To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training + <./t5.html#training>`__. + decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + encoder_outputs (:obj:`tuple(tuple(jnp.ndarray)`, `optional`): + Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: + `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a + sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of + the decoder. + past_key_values (:obj:`tuple(tuple(jnp.ndarray))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + + + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + class FlaxT5PreTrainedModel(FlaxPreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -884,6 +954,7 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel): decoder_attention_mask, )["params"] + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) def __call__( self, input_ids: jnp.ndarray, @@ -1155,71 +1226,6 @@ T5_START_DOCSTRING = r""" model weights. """ -T5_INPUTS_DOCSTRING = r""" - Args: - input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you - should be able to pad the inputs on both the right and the left. - - Indices can be obtained using :class:`~transformers.T5Tokenizer`. See - :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for - detail. - - `What are input IDs? <../glossary.html#input-ids>`__ - - To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training - <./t5.html#training>`__. - attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - `What are attention masks? <../glossary.html#attention-mask>`__ - decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): - Indices of decoder input sequence tokens in the vocabulary. - - Indices can be obtained using :class:`~transformers.T5Tokenizer`. See - :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for - details. - - `What are decoder input IDs? <../glossary.html#decoder-input-ids>`__ - - T5 uses the :obj:`pad_token_id` as the starting token for :obj:`decoder_input_ids` generation. If - :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see - :obj:`past_key_values`). - - To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training - <./t5.html#training>`__. - decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): - Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will - also be used by default. - encoder_outputs (:obj:`tuple(tuple(jnp.ndarray)`, `optional`): - Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: - `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a - sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of - the decoder. - past_key_values (:obj:`tuple(tuple(jnp.ndarray))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - - If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` - (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` - instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. - - use_cache (:obj:`bool`, `optional`): - If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up - decoding (see :obj:`past_key_values`). - - output_attentions (:obj:`bool`, `optional`): - Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned - tensors for more detail. - output_hidden_states (:obj:`bool`, `optional`): - Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for - more detail. - return_dict (:obj:`bool`, `optional`): - Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. -""" - @add_start_docstrings( "The bare T5 Model transformer outputting raw hidden-states" "without any specific head on top.", @@ -1252,8 +1258,6 @@ class FlaxT5Module(nn.Module): decoder_config.num_layers = self.config.num_decoder_layers self.decoder = FlaxT5Stack(decoder_config, embed_tokens=self.shared, dtype=self.dtype) - @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) def __call__( self, input_ids=None, @@ -1266,22 +1270,6 @@ class FlaxT5Module(nn.Module): return_dict=None, deterministic: bool = True, ): - r""" - Returns: - - Example:: - - >>> from transformers import T5Tokenizer, FlaxT5Model - - >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') - >>> model = FlaxT5Model.from_pretrained('t5-small') - - >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="np").input_ids # Batch size 1 - >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids # Batch size 1 - >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - - >>> last_hidden_states = outputs.last_hidden_state - """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Encode if needed (training, first prediction pass) @@ -1325,6 +1313,32 @@ class FlaxT5Model(FlaxT5PreTrainedModel): module_class = FlaxT5Module +append_call_sample_docstring( + FlaxT5Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC +) + +FLAX_T5_MODEL_DOCSTRING = """ + Returns: + + Example:: + + >>> from transformers import T5Tokenizer, FlaxT5Model + + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + >>> model = FlaxT5Model.from_pretrained('t5-small') + + >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="np").input_ids + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + + >>> last_hidden_states = outputs.last_hidden_state +""" + + +overwrite_call_docstring(FlaxT5Model, T5_INPUTS_DOCSTRING + FLAX_T5_MODEL_DOCSTRING) +append_replace_return_docstrings(FlaxT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + + @add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING) class FlaxT5ForConditionalGenerationModule(nn.Module): config: T5Config @@ -1364,8 +1378,6 @@ class FlaxT5ForConditionalGenerationModule(nn.Module): dtype=self.dtype, ) - @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) def __call__( self, input_ids=None, @@ -1378,24 +1390,6 @@ class FlaxT5ForConditionalGenerationModule(nn.Module): return_dict=None, deterministic: bool = True, ): - r""" - Returns: - - Examples:: - - >>> from transformers import T5Tokenizer, T5ForConditionalGeneration - - >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') - >>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small') - - >>> input_ids = tokenizer('The walks in park', return_tensors='np').input_ids - >>> decoder_input_ids = tokenizer(' cute dog the ', return_tensors='np').input_ids - >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - >>> logits = outputs.logits - - >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="np").input_ids - >>> outputs = model.generate(input_ids) - """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Encode @@ -1479,7 +1473,7 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel): >>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small') >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') - >>> text = "My friends are cool but they eat too many carbs." + >>> text = "summarize: My friends are cool but they eat too many carbs." >>> inputs = tokenizer(text, max_length=512, return_tensors='jax') >>> encoder_outputs = model.encode(**inputs) @@ -1614,3 +1608,30 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel): def update_inputs_for_generation(self, model_outputs, model_kwargs): model_kwargs["past_key_values"] = model_outputs.past_key_values return model_kwargs + + +FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING = """ + Returns: + + Example:: + + >>> from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration + + >>> model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small') + >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') + + >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs." + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=512, return_tensors='jax') + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs['input_ids']).sequences + >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)) +""" + + +overwrite_call_docstring( + FlaxT5ForConditionalGeneration, T5_INPUTS_DOCSTRING + FLAX_T5_CONDITIONAL_GENERATION_DOCSTRING +) +append_replace_return_docstrings( + FlaxT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC +) diff --git a/src/transformers/models/vit/modeling_flax_vit.py b/src/transformers/models/vit/modeling_flax_vit.py index 7ce86664e3..7b448da8b0 100644 --- a/src/transformers/models/vit/modeling_flax_vit.py +++ b/src/transformers/models/vit/modeling_flax_vit.py @@ -581,7 +581,7 @@ FLAX_VISION_CLASSIF_DOCSTRING = """ Example:: - >>> from transformers import FlaxViTFeatureExtractor, ViTForImageClassification + >>> from transformers import ViTFeatureExtractor, FlaxViTForImageClassification >>> from PIL import Image >>> import jax >>> import requests @@ -595,9 +595,10 @@ FLAX_VISION_CLASSIF_DOCSTRING = """ >>> inputs = feature_extractor(images=image, return_tensors="jax") >>> outputs = model(**inputs) >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) - >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()]) """ overwrite_call_docstring(FlaxViTForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING) diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py index e95e21f909..34281c0068 100644 --- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py @@ -29,7 +29,12 @@ from jax import lax from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput -from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_replace_return_docstrings, + overwrite_call_docstring, +) from ...utils import logging from .configuration_wav2vec2 import Wav2Vec2Config @@ -853,31 +858,6 @@ class FlaxWav2Vec2Module(nn.Module): output_hidden_states=None, return_dict=None, ): - """ - - Returns: - - Example:: - - >>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2Model - >>> from datasets import load_dataset - >>> import soundfile as sf - - >>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") - >>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") - - >>> def map_to_array(batch): - >>> speech, _ = sf.read(batch["file"]) - >>> batch["speech"] = speech - >>> return batch - - >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") - >>> ds = ds.map(map_to_array) - - >>> input_values = processor(ds["speech"][0], return_tensors="np").input_values # Batch size 1 - >>> hidden_states = model(input_values).last_hidden_state - - """ extract_features = self.feature_extractor(input_values) # make sure that no loss is computed on padded inputs @@ -947,6 +927,39 @@ class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel): module_class = FlaxWav2Vec2Module +FLAX_WAV2VEC2_MODEL_DOCSTRING = """ + Returns: + + Example:: + + >>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2Model + >>> from datasets import load_dataset + >>> import soundfile as sf + + >>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-lv60") + >>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60") + + >>> def map_to_array(batch): + >>> speech, _ = sf.read(batch["file"]) + >>> batch["speech"] = speech + >>> return batch + + >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_values = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="np").input_values # Batch size 1 + >>> hidden_states = model(input_values).last_hidden_state +""" + +overwrite_call_docstring( + FlaxWav2Vec2Model, + WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_MODEL_DOCSTRING, +) +append_replace_return_docstrings( + FlaxWav2Vec2Model, output_type=FlaxWav2Vec2BaseModelOutput, config_class=Wav2Vec2Config +) + + class FlaxWav2Vec2ForCTCModule(nn.Module): config: Wav2Vec2Config dtype: jnp.dtype = jnp.float32 @@ -970,36 +983,6 @@ class FlaxWav2Vec2ForCTCModule(nn.Module): output_hidden_states=None, return_dict=None, ): - r""" - Returns: - - Example:: - - >>> import jax.numpy as jnp - >>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2ForCTC - >>> from datasets import load_dataset - >>> import soundfile as sf - - >>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") - >>> model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") - - >>> def map_to_array(batch): - >>> speech, _ = sf.read(batch["file"]) - >>> batch["speech"] = speech - >>> return batch - - >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") - >>> ds = ds.map(map_to_array) - - >>> input_values = processor(ds["speech"][0], return_tensors="np").input_values # Batch size 1 - >>> logits = model(input_values).logits - >>> predicted_ids = jnp.argmax(logits, axis=-1) - - >>> transcription = processor.decode(predicted_ids[0]) - >>> # should give: "A MAN SAID TO THE UNIVERSE SIR I EXIST" - - """ - outputs = self.wav2vec2( input_values, attention_mask=attention_mask, @@ -1044,6 +1027,46 @@ class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel): module_class = FlaxWav2Vec2ForCTCModule +FLAX_WAV2VEC2_FOR_CTC_DOCSTRING = """ + Returns: + + Example:: + + >>> import jax.numpy as jnp + >>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2ForCTC + >>> from datasets import load_dataset + >>> import soundfile as sf + + >>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60") + >>> model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60") + + >>> def map_to_array(batch): + >>> speech, _ = sf.read(batch["file"]) + >>> batch["speech"] = speech + >>> return batch + + >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_values = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="np").input_values # Batch size 1 + >>> logits = model(input_values).logits + >>> predicted_ids = jnp.argmax(logits, axis=-1) + + >>> transcription = processor.decode(predicted_ids[0]) + >>> # should give: "A MAN SAID TO THE UNIVERSE SIR I EXIST" +""" + +overwrite_call_docstring( + FlaxWav2Vec2ForCTC, + WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_CTC_DOCSTRING, +) +append_replace_return_docstrings(FlaxWav2Vec2ForCTC, output_type=FlaxCausalLMOutput, config_class=Wav2Vec2Config) + + +class FlaxWav2Vec2ForCTCModule(nn.Module): + config: Wav2Vec2Config + + class FlaxWav2Vec2ForPreTrainingModule(nn.Module): config: Wav2Vec2Config dtype: jnp.dtype = jnp.float32 @@ -1080,43 +1103,6 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module): Example:: - >>> import optax - >>> import numpy as np - >>> import jax.numpy as jnp - >>> from transformers import Wav2Vec2FeatureExtractor, FlaxWav2Vec2ForPreTraining - >>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices - >>> from datasets import load_dataset - >>> import soundfile as sf - - >>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base") - >>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base") - - - >>> def map_to_array(batch): - ... speech, _ = sf.read(batch["file"]) - ... batch["speech"] = speech - ... return batch - - - >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") - >>> ds = ds.map(map_to_array) - - >>> input_values = feature_extractor(ds["speech"][0], return_tensors="np").input_values # Batch size 1 - - >>> # compute masked indices - >>> batch_size, raw_sequence_length = input_values.shape - >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length) - >>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2) - - >>> outputs = model(input_values, mask_time_indices=mask_time_indices) - - >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states) - >>> cosine_sim = optax.cosine_similarity( - ... outputs.projected_states, outputs.projected_quantized_states, axis=-1 - ... ) - - >>> # show that cosine similarity is much higher than random - >>> assert np.asarray(cosine_sim)[mask_time_indices].mean() > 0.5 """ @@ -1222,3 +1208,60 @@ class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel): return_dict, rngs=rngs, ) + + +FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING = """ + Returns: + + Example:: + + >>> import optax + >>> import numpy as np + >>> import jax.numpy as jnp + >>> from transformers import Wav2Vec2FeatureExtractor, FlaxWav2Vec2ForPreTraining + >>> from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices + >>> from datasets import load_dataset + >>> import soundfile as sf + + >>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-lv60") + >>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60") + + + >>> def map_to_array(batch): + ... speech, _ = sf.read(batch["file"]) + ... batch["speech"] = speech + ... return batch + + + >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_values = feature_extractor(ds["speech"][0], return_tensors="np").input_values # Batch size 1 + + >>> # compute masked indices + >>> batch_size, raw_sequence_length = input_values.shape + >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length) + >>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2) + + >>> outputs = model(input_values, mask_time_indices=mask_time_indices) + + >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states) + >>> cosine_sim = optax.cosine_similarity( + ... outputs.projected_states, outputs.projected_quantized_states + ... ) + + >>> # show that cosine similarity is much higher than random + >>> assert np.asarray(cosine_sim)[mask_time_indices].mean() > 0.5 +""" + +overwrite_call_docstring( + FlaxWav2Vec2ForPreTraining, + WAV_2_VEC_2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING, +) +append_replace_return_docstrings( + FlaxWav2Vec2ForPreTraining, output_type=FlaxWav2Vec2ForPreTrainingOutput, config_class=Wav2Vec2Config +) + + +class FlaxWav2Vec2ForCTCModule(nn.Module): + config: Wav2Vec2Config diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index a15ae444a5..9454f5b005 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1183,7 +1183,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel): return logits @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=Wav2Vec2ForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_values, @@ -1338,7 +1338,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel): self.init_weights() @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=Wav2Vec2BaseModelOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_values, @@ -1420,7 +1420,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): self.wav2vec2.feature_extractor._freeze_parameters() @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, input_values,