diff --git a/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py index af4786eaf1..df7693ef0b 100644 --- a/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py +++ b/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py @@ -224,7 +224,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel): `What are input IDs? <../glossary.html#input-ids>`__ Returns: - text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The text embeddings + text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The text embeddings obtained by applying the projection layer to the pooled output of text model. """ if position_ids is None: @@ -273,7 +273,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel): :meth:`transformers.ImageFeatureExtractionMixin.__call__` for details. Returns: - image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The image embeddings + image_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The image embeddings obtained by applying the projection layer to the pooled output of vision model. """ diff --git a/src/transformers/generation_flax_logits_process.py b/src/transformers/generation_flax_logits_process.py index c6179e63fc..f076839c39 100644 --- a/src/transformers/generation_flax_logits_process.py +++ b/src/transformers/generation_flax_logits_process.py @@ -19,7 +19,6 @@ from abc import ABC import jax import jax.lax as lax import jax.numpy as jnp -import jaxlib.xla_extension as jax_xla from .file_utils import add_start_docstrings from .utils.logging import get_logger @@ -30,7 +29,7 @@ logger = get_logger(__name__) LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): + input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See @@ -38,14 +37,14 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" details. `What are input IDs? <../glossary.html#input-ids>`__ - scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`): + scores (:obj:`jnp.ndarray` of shape :obj:`(batch_size, config.vocab_size)`): Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam search or log softmax for each vocabulary token when using beam search kwargs: Additional logits processor specific kwargs. Return: - :obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores. + :obj:`jnp.ndarray` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores. """ @@ -54,7 +53,7 @@ class FlaxLogitsProcessor(ABC): """Abstract base class for all logit processors that can be applied during generation.""" @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray: + def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray: """Flax method for processing logits.""" raise NotImplementedError( f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." @@ -65,7 +64,7 @@ class FlaxLogitsWarper(ABC): """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray: + def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray: """Flax method for warping logits.""" raise NotImplementedError( f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." @@ -81,9 +80,7 @@ class FlaxLogitsProcessorList(list): """ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__( - self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int, **kwargs - ) -> jax_xla.DeviceArray: + def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int, **kwargs) -> jnp.ndarray: for processor in self: function_args = inspect.signature(processor.__call__).parameters if len(function_args) > 3: @@ -111,9 +108,7 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper): self.temperature = temperature - def __call__( - self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int - ) -> jax_xla.DeviceArray: + def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: scores = scores / self.temperature return scores @@ -141,9 +136,7 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper): self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep - def __call__( - self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int - ) -> jax_xla.DeviceArray: + def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1]) mask_scores = jnp.full_like(scores, self.filter_value) @@ -183,9 +176,7 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper): self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep - def __call__( - self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int - ) -> jax_xla.DeviceArray: + def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: batch_size, vocab_size = scores.shape next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value) @@ -212,9 +203,7 @@ class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor): def __init__(self, bos_token_id: int): self.bos_token_id = bos_token_id - def __call__( - self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int - ) -> jax_xla.DeviceArray: + def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: new_scores = jnp.full(scores.shape, -float("inf")) apply_penalty = 1 - jnp.bool_(cur_len - 1) @@ -242,9 +231,7 @@ class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor): self.max_length = max_length self.eos_token_id = eos_token_id - def __call__( - self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int - ) -> jax_xla.DeviceArray: + def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: new_scores = jnp.full(scores.shape, -float("inf")) apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1) @@ -277,9 +264,7 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor): self.min_length = min_length self.eos_token_id = eos_token_id - def __call__( - self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int - ) -> jax_xla.DeviceArray: + def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: # create boolean flag to decide if min length penalty should be applied apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1) diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py index 3d868d4c9d..47d2c5035c 100644 --- a/src/transformers/generation_flax_utils.py +++ b/src/transformers/generation_flax_utils.py @@ -23,7 +23,6 @@ import numpy as np import flax import jax import jax.numpy as jnp -import jaxlib.xla_extension as jax_xla from jax import lax from .file_utils import ModelOutput @@ -49,11 +48,11 @@ class FlaxGreedySearchOutput(ModelOutput): Args: - sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`): + sequences (:obj:`jnp.ndarray` of shape :obj:`(batch_size, max_length)`): The generated sequences. """ - sequences: jax_xla.DeviceArray = None + sequences: jnp.ndarray = None @flax.struct.dataclass @@ -63,11 +62,11 @@ class FlaxSampleOutput(ModelOutput): Args: - sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`): + sequences (:obj:`jnp.ndarray` of shape :obj:`(batch_size, max_length)`): The generated sequences. """ - sequences: jax_xla.DeviceArray = None + sequences: jnp.ndarray = None @flax.struct.dataclass @@ -77,44 +76,44 @@ class FlaxBeamSearchOutput(ModelOutput): Args: - sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`): + sequences (:obj:`jnp.ndarray` of shape :obj:`(batch_size, max_length)`): The generated sequences. - scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size,)`): + scores (:obj:`jnp.ndarray` of shape :obj:`(batch_size,)`): The scores (log probabilites) of the generated sequences. """ - sequences: jax_xla.DeviceArray = None - scores: jax_xla.DeviceArray = None + sequences: jnp.ndarray = None + scores: jnp.ndarray = None @flax.struct.dataclass class GreedyState: - cur_len: jax_xla.DeviceArray - sequences: jax_xla.DeviceArray - running_token: jax_xla.DeviceArray - is_sent_finished: jax_xla.DeviceArray - model_kwargs: Dict[str, jax_xla.DeviceArray] + cur_len: jnp.ndarray + sequences: jnp.ndarray + running_token: jnp.ndarray + is_sent_finished: jnp.ndarray + model_kwargs: Dict[str, jnp.ndarray] @flax.struct.dataclass class SampleState: - cur_len: jax_xla.DeviceArray - sequences: jax_xla.DeviceArray - running_token: jax_xla.DeviceArray - is_sent_finished: jax_xla.DeviceArray - prng_key: jax_xla.DeviceArray - model_kwargs: Dict[str, jax_xla.DeviceArray] + cur_len: jnp.ndarray + sequences: jnp.ndarray + running_token: jnp.ndarray + is_sent_finished: jnp.ndarray + prng_key: jnp.ndarray + model_kwargs: Dict[str, jnp.ndarray] @flax.struct.dataclass class BeamSearchState: - cur_len: jax_xla.DeviceArray - running_sequences: jax_xla.DeviceArray - running_scores: jax_xla.DeviceArray - sequences: jax_xla.DeviceArray - scores: jax_xla.DeviceArray - is_sent_finished: jax_xla.DeviceArray - model_kwargs: Dict[str, jax_xla.DeviceArray] + cur_len: jnp.ndarray + running_sequences: jnp.ndarray + running_scores: jnp.ndarray + sequences: jnp.ndarray + scores: jnp.ndarray + is_sent_finished: jnp.ndarray + model_kwargs: Dict[str, jnp.ndarray] class FlaxGenerationMixin: @@ -156,14 +155,14 @@ class FlaxGenerationMixin: def generate( self, - input_ids: jax_xla.DeviceArray, + input_ids: jnp.ndarray, max_length: Optional[int] = None, pad_token_id: Optional[int] = None, bos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, decoder_start_token_id: Optional[int] = None, do_sample: Optional[bool] = None, - prng_key: Optional[jax_xla.DeviceArray] = None, + prng_key: Optional[jnp.ndarray] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None, @@ -175,7 +174,7 @@ class FlaxGenerationMixin: length_penalty: Optional[float] = None, early_stopping: Optional[bool] = None, trace: bool = True, - params: Optional[Dict[str, jax_xla.DeviceArray]] = None, + params: Optional[Dict[str, jnp.ndarray]] = None, **model_kwargs, ): r""" @@ -191,7 +190,7 @@ class FlaxGenerationMixin: Parameters: - input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): The sequence used as a prompt for the generation. max_length (:obj:`int`, `optional`, defaults to 20): The maximum length of the sequence to be generated. @@ -217,7 +216,7 @@ class FlaxGenerationMixin: trace (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to a considerably slower runtime. - params (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`): + params (:obj:`Dict[str, jnp.ndarray]`, `optional`): Optionally the model parameters can be passed. Can be useful for parallelized generation. model_kwargs: Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. @@ -395,8 +394,8 @@ class FlaxGenerationMixin: eos_token_id: Optional[int] = None, logits_processor: Optional[FlaxLogitsProcessorList] = None, trace: bool = True, - params: Optional[Dict[str, jax_xla.DeviceArray]] = None, - model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None, + params: Optional[Dict[str, jnp.ndarray]] = None, + model_kwargs: Optional[Dict[str, jnp.ndarray]] = None, ): # init values max_length = max_length if max_length is not None else self.config.max_length @@ -479,12 +478,12 @@ class FlaxGenerationMixin: max_length: Optional[int] = None, pad_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, - prng_key: Optional[jax_xla.DeviceArray] = None, + prng_key: Optional[jnp.ndarray] = None, logits_processor: Optional[FlaxLogitsProcessorList] = None, logits_warper: Optional[FlaxLogitsProcessorList] = None, trace: bool = True, - params: Optional[Dict[str, jax_xla.DeviceArray]] = None, - model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None, + params: Optional[Dict[str, jnp.ndarray]] = None, + model_kwargs: Optional[Dict[str, jnp.ndarray]] = None, ): # init values max_length = max_length if max_length is not None else self.config.max_length @@ -580,8 +579,8 @@ class FlaxGenerationMixin: early_stopping: Optional[bool] = None, logits_processor: Optional[FlaxLogitsProcessorList] = None, trace: bool = True, - params: Optional[Dict[str, jax_xla.DeviceArray]] = None, - model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None, + params: Optional[Dict[str, jnp.ndarray]] = None, + model_kwargs: Optional[Dict[str, jnp.ndarray]] = None, ): """ This beam search function is heavily inspired by Flax's official example: diff --git a/src/transformers/modeling_flax_outputs.py b/src/transformers/modeling_flax_outputs.py index b2929ee134..c748a4f72e 100644 --- a/src/transformers/modeling_flax_outputs.py +++ b/src/transformers/modeling_flax_outputs.py @@ -14,7 +14,7 @@ from typing import Dict, Optional, Tuple import flax -import jaxlib.xla_extension as jax_xla +import jax.numpy as jnp from .file_utils import ModelOutput @@ -25,24 +25,24 @@ class FlaxBaseModelOutput(ModelOutput): Base class for model's outputs, with potential hidden states and attentions. Args: - last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. - hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ - last_hidden_state: jax_xla.DeviceArray = None - hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + last_hidden_state: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None @flax.struct.dataclass @@ -51,28 +51,28 @@ class FlaxBaseModelOutputWithPast(ModelOutput): Base class for model's outputs, with potential hidden states and attentions. Args: - last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. - past_key_values (:obj:`Dict[str, jax_xla.DeviceArray]`): + past_key_values (:obj:`Dict[str, jnp.ndarray]`): Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`. - hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ - last_hidden_state: jax_xla.DeviceArray = None - past_key_values: Optional[Dict[str, jax_xla.DeviceArray]] = None - hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + last_hidden_state: jnp.ndarray = None + past_key_values: Optional[Dict[str, jnp.ndarray]] = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None @flax.struct.dataclass @@ -81,29 +81,29 @@ class FlaxBaseModelOutputWithPooling(ModelOutput): Base class for model's outputs that also contains a pooling of the last hidden states. Args: - last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. - pooler_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`): + pooler_output (:obj:`jnp.ndarray` of shape :obj:`(batch_size, hidden_size)`): Last layer hidden-state of the first token of the sequence (classification token) further processed by a Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence prediction (classification) objective during pretraining. - hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ - last_hidden_state: jax_xla.DeviceArray = None - pooler_output: jax_xla.DeviceArray = None - hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + last_hidden_state: jnp.ndarray = None + pooler_output: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None @flax.struct.dataclass @@ -112,44 +112,44 @@ class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput): Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). Args: - last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output. - past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): - Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2 - tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of + shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if ``config.is_encoder_decoder=True`` 2 additional tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if ``config.is_encoder_decoder=True`` in the cross-attention blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding. - hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. - cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads. """ - last_hidden_state: jax_xla.DeviceArray = None - past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None - hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - attentions: Optional[Tuple[jax_xla.DeviceArray]] = None - cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + last_hidden_state: jnp.ndarray = None + past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + cross_attentions: Optional[Tuple[jnp.ndarray]] = None @flax.struct.dataclass @@ -159,58 +159,58 @@ class FlaxSeq2SeqModelOutput(ModelOutput): decoding. Args: - last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the decoder of the model. If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output. - past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): - Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2 - tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional - tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of + shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. - cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads. - encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads. """ - last_hidden_state: jax_xla.DeviceArray = None - past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None - decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None - cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None - encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None - encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + last_hidden_state: jnp.ndarray = None + past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None + decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + decoder_attentions: Optional[Tuple[jnp.ndarray]] = None + cross_attentions: Optional[Tuple[jnp.ndarray]] = None + encoder_last_hidden_state: Optional[jnp.ndarray] = None + encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + encoder_attentions: Optional[Tuple[jnp.ndarray]] = None @flax.struct.dataclass @@ -219,39 +219,39 @@ class FlaxCausalLMOutputWithCrossAttentions(ModelOutput): Base class for causal language model (or autoregressive) outputs. Args: - logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. - cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Cross attentions weights after the attention softmax, used to compute the weighted average in the cross-attention heads. - past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): - Tuple of :obj:`jax_xla.DeviceArray` tuples of length :obj:`config.n_layers`, with each tuple containing the - cached key, value states of the self-attention and the cross-attention layers if model is used in - encoder-decoder setting. Only relevant if ``config.is_decoder = True``. + past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + Tuple of :obj:`jnp.ndarray` tuples of length :obj:`config.n_layers`, with each tuple containing the cached + key, value states of the self-attention and the cross-attention layers if model is used in encoder-decoder + setting. Only relevant if ``config.is_decoder = True``. Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding. """ - logits: jax_xla.DeviceArray = None - past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None - hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - attentions: Optional[Tuple[jax_xla.DeviceArray]] = None - cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + logits: jnp.ndarray = None + past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + cross_attentions: Optional[Tuple[jnp.ndarray]] = None @flax.struct.dataclass @@ -260,24 +260,24 @@ class FlaxMaskedLMOutput(ModelOutput): Base class for masked language models outputs. Args: - logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ - logits: jax_xla.DeviceArray = None - hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None FlaxCausalLMOutput = FlaxMaskedLMOutput @@ -289,55 +289,55 @@ class FlaxSeq2SeqLMOutput(ModelOutput): Base class for sequence-to-sequence language models outputs. Args: - logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): - Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2 - tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional - tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of + shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. - cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads. - encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads. """ - logits: jax_xla.DeviceArray = None - past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None - decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None - cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None - encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None - encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + logits: jnp.ndarray = None + past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None + decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + decoder_attentions: Optional[Tuple[jnp.ndarray]] = None + cross_attentions: Optional[Tuple[jnp.ndarray]] = None + encoder_last_hidden_state: Optional[jnp.ndarray] = None + encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + encoder_attentions: Optional[Tuple[jnp.ndarray]] = None @flax.struct.dataclass @@ -346,25 +346,25 @@ class FlaxNextSentencePredictorOutput(ModelOutput): Base class for outputs of models predicting if two sentences are consecutive or not. Args: - logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`): + logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, 2)`): Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). - hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ - logits: jax_xla.DeviceArray = None - hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None @flax.struct.dataclass @@ -373,24 +373,24 @@ class FlaxSequenceClassifierOutput(ModelOutput): Base class for outputs of sentence classification models. Args: - logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`): + logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, config.num_labels)`): Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ - logits: jax_xla.DeviceArray = None - hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None @flax.struct.dataclass @@ -399,55 +399,55 @@ class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput): Base class for outputs of sequence-to-sequence sentence classification models. Args: - logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`): + logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, config.num_labels)`): Classification (or regression if config.num_labels==1) scores (before SoftMax). - past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): - Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2 - tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional - tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of + shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. - cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads. - encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads. """ - logits: jax_xla.DeviceArray = None - past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None - decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None - cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None - encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None - encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + logits: jnp.ndarray = None + past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None + decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + decoder_attentions: Optional[Tuple[jnp.ndarray]] = None + cross_attentions: Optional[Tuple[jnp.ndarray]] = None + encoder_last_hidden_state: Optional[jnp.ndarray] = None + encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + encoder_attentions: Optional[Tuple[jnp.ndarray]] = None @flax.struct.dataclass @@ -456,26 +456,26 @@ class FlaxMultipleChoiceModelOutput(ModelOutput): Base class for outputs of multiple choice models. Args: - logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, num_choices)`): + logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, num_choices)`): `num_choices` is the second dimension of the input tensors. (see `input_ids` above). Classification scores (before SoftMax). - hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ - logits: jax_xla.DeviceArray = None - hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None @flax.struct.dataclass @@ -484,24 +484,24 @@ class FlaxTokenClassifierOutput(ModelOutput): Base class for outputs of token classification models. Args: - logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.num_labels)`): + logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.num_labels)`): Classification scores (before SoftMax). - hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ - logits: jax_xla.DeviceArray = None - hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None @flax.struct.dataclass @@ -510,27 +510,27 @@ class FlaxQuestionAnsweringModelOutput(ModelOutput): Base class for outputs of question answering models. Args: - start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): + start_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`): Span-start scores (before SoftMax). - end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): + end_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`): Span-end scores (before SoftMax). - hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ - start_logits: jax_xla.DeviceArray = None - end_logits: jax_xla.DeviceArray = None - hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + start_logits: jnp.ndarray = None + end_logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None @flax.struct.dataclass @@ -539,55 +539,55 @@ class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput): Base class for outputs of sequence-to-sequence question answering models. Args: - start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): + start_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`): Span-start scores (before SoftMax). - end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): + end_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`): Span-end scores (before SoftMax). - past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): - Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2 - tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional - tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of + shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. - cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads. - encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads. """ - start_logits: jax_xla.DeviceArray = None - end_logits: jax_xla.DeviceArray = None - past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None - decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None - cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None - encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None - encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + start_logits: jnp.ndarray = None + end_logits: jnp.ndarray = None + past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None + decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + decoder_attentions: Optional[Tuple[jnp.ndarray]] = None + cross_attentions: Optional[Tuple[jnp.ndarray]] = None + encoder_last_hidden_state: Optional[jnp.ndarray] = None + encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None + encoder_attentions: Optional[Tuple[jnp.ndarray]] = None diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 2ec002cd3c..5da0fa01e8 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -21,7 +21,6 @@ import flax import flax.linen as nn import jax import jax.numpy as jnp -import jaxlib.xla_extension as jax_xla from flax.core.frozen_dict import FrozenDict from flax.linen.attention import dot_product_attention_weights from jax import lax @@ -61,28 +60,28 @@ class FlaxBertForPreTrainingOutput(ModelOutput): Output type of :class:`~transformers.BertForPreTraining`. Args: - prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + prediction_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`): + seq_relationship_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, 2)`): Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). - hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ - prediction_logits: jax_xla.DeviceArray = None - seq_relationship_logits: jax_xla.DeviceArray = None - hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + prediction_logits: jnp.ndarray = None + seq_relationship_logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None BERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index 20526cc14c..46809667dd 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -21,7 +21,6 @@ import flax import flax.linen as nn import jax import jax.numpy as jnp -import jaxlib.xla_extension as jax_xla from flax.core.frozen_dict import FrozenDict from flax.linen.attention import dot_product_attention_weights from jax import lax @@ -59,28 +58,28 @@ class FlaxBigBirdForPreTrainingOutput(ModelOutput): Output type of :class:`~transformers.BigBirdForPreTraining`. Args: - prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + prediction_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`): + seq_relationship_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, 2)`): Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax). - hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ - prediction_logits: jax_xla.DeviceArray = None - seq_relationship_logits: jax_xla.DeviceArray = None - hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + prediction_logits: jnp.ndarray = None + seq_relationship_logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None @flax.struct.dataclass @@ -89,30 +88,30 @@ class FlaxBigBirdForQuestionAnsweringModelOutput(ModelOutput): Base class for outputs of question answering models. Args: - start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): + start_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`): Span-start scores (before SoftMax). - end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): + end_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`): Span-end scores (before SoftMax). - pooled_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`): + pooled_output (:obj:`jnp.ndarray` of shape :obj:`(batch_size, hidden_size)`): pooled_output returned by FlaxBigBirdModel. - hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ - start_logits: jax_xla.DeviceArray = None - end_logits: jax_xla.DeviceArray = None - pooled_output: jax_xla.DeviceArray = None - hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + start_logits: jnp.ndarray = None + end_logits: jnp.ndarray = None + pooled_output: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None BIG_BIRD_START_DOCSTRING = r""" diff --git a/src/transformers/models/clip/modeling_flax_clip.py b/src/transformers/models/clip/modeling_flax_clip.py index 4b3a311d1d..e2be39da27 100644 --- a/src/transformers/models/clip/modeling_flax_clip.py +++ b/src/transformers/models/clip/modeling_flax_clip.py @@ -19,7 +19,6 @@ import flax import flax.linen as nn import jax import jax.numpy as jnp -import jaxlib.xla_extension as jax_xla from flax.core.frozen_dict import FrozenDict from flax.linen import combine_masks, make_causal_mask from flax.linen.attention import dot_product_attention_weights @@ -156,16 +155,16 @@ CLIP_INPUTS_DOCSTRING = r""" class FlaxCLIPOutput(ModelOutput): """ Args: - logits_per_image:(:obj:`jax_xla.DeviceArray` of shape :obj:`(image_batch_size, text_batch_size)`): + logits_per_image:(:obj:`jnp.ndarray` of shape :obj:`(image_batch_size, text_batch_size)`): The scaled dot product scores between :obj:`image_embeds` and :obj:`text_embeds`. This represents the image-text similarity scores. - logits_per_text:(:obj:`jax_xla.DeviceArray` of shape :obj:`(text_batch_size, image_batch_size)`): + logits_per_text:(:obj:`jnp.ndarray` of shape :obj:`(text_batch_size, image_batch_size)`): The scaled dot product scores between :obj:`text_embeds` and :obj:`image_embeds`. This represents the text-image similarity scores. - text_embeds(:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): + text_embeds(:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The text embeddings obtained by applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPTextModel`. - image_embeds(:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): + image_embeds(:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The image embeddings obtained by applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPVisionModel`. text_model_output(:obj:`FlaxBaseModelOutputWithPooling`): @@ -174,10 +173,10 @@ class FlaxCLIPOutput(ModelOutput): The output of the :class:`~transformers.FlaxCLIPVisionModel`. """ - logits_per_image: jax_xla.DeviceArray = None - logits_per_text: jax_xla.DeviceArray = None - text_embeds: jax_xla.DeviceArray = None - image_embeds: jax_xla.DeviceArray = None + logits_per_image: jnp.ndarray = None + logits_per_text: jnp.ndarray = None + text_embeds: jnp.ndarray = None + image_embeds: jnp.ndarray = None text_model_output: FlaxBaseModelOutputWithPooling = None vision_model_output: FlaxBaseModelOutputWithPooling = None @@ -801,8 +800,8 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): `What are input IDs? <../glossary.html#input-ids>`__ Returns: - text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The text embeddings - obtained by applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPTextModel`. + text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPTextModel`. Examples:: @@ -855,9 +854,8 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): :meth:`transformers.CLIPFeatureExtractor.__call__` for details. Returns: - image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The image embeddings - obtained by applying the projection layer to the pooled output of - :class:`~transformers.FlaxCLIPVisionModel` + image_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The image embeddings obtained + by applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPVisionModel` Examples:: diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index cbd7b00c6e..43c38fcdd3 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -21,7 +21,6 @@ import flax import flax.linen as nn import jax import jax.numpy as jnp -import jaxlib.xla_extension as jax_xla from flax.core.frozen_dict import FrozenDict from flax.linen.attention import dot_product_attention_weights from jax import lax @@ -60,24 +59,24 @@ class FlaxElectraForPreTrainingOutput(ModelOutput): Output type of :class:`~transformers.ElectraForPreTraining`. Args: - logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each - layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. + hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of + shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. + attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ - logits: jax_xla.DeviceArray = None - hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None - attentions: Optional[Tuple[jax_xla.DeviceArray]] = None + logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None ELECTRA_START_DOCSTRING = r""" diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 2646751459..c959a04ac2 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -44,7 +44,6 @@ if is_flax_available(): import jax import jax.numpy as jnp - import jaxlib.xla_extension as jax_xla from flax.core.frozen_dict import unfreeze from flax.traverse_util import flatten_dict from transformers import ( @@ -127,7 +126,7 @@ class FlaxModelTesterMixin: if "ForMultipleChoice" in model_class.__name__: inputs_dict = { k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1])) - if isinstance(v, (jax_xla.DeviceArray, np.ndarray)) + if isinstance(v, (jnp.ndarray, np.ndarray)) else v for k, v in inputs_dict.items() }