[Flax] Align jax flax device name (#12987)
* [Flax] Align device name in docs * make style * fix import error
This commit is contained in:
committed by
GitHub
parent
07df5578d9
commit
da9754a3a0
@@ -224,7 +224,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
|||||||
`What are input IDs? <../glossary.html#input-ids>`__
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The text embeddings
|
text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`): The text embeddings
|
||||||
obtained by applying the projection layer to the pooled output of text model.
|
obtained by applying the projection layer to the pooled output of text model.
|
||||||
"""
|
"""
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
@@ -273,7 +273,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
|||||||
:meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
|
:meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The image embeddings
|
image_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`): The image embeddings
|
||||||
obtained by applying the projection layer to the pooled output of vision model.
|
obtained by applying the projection layer to the pooled output of vision model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from abc import ABC
|
|||||||
import jax
|
import jax
|
||||||
import jax.lax as lax
|
import jax.lax as lax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jaxlib.xla_extension as jax_xla
|
|
||||||
|
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
from .utils.logging import get_logger
|
from .utils.logging import get_logger
|
||||||
@@ -30,7 +29,7 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
|
input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
Indices of input sequence tokens in the vocabulary.
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
|
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
|
||||||
@@ -38,14 +37,14 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
|||||||
details.
|
details.
|
||||||
|
|
||||||
`What are input IDs? <../glossary.html#input-ids>`__
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`):
|
scores (:obj:`jnp.ndarray` of shape :obj:`(batch_size, config.vocab_size)`):
|
||||||
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
|
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
|
||||||
search or log softmax for each vocabulary token when using beam search
|
search or log softmax for each vocabulary token when using beam search
|
||||||
kwargs:
|
kwargs:
|
||||||
Additional logits processor specific kwargs.
|
Additional logits processor specific kwargs.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
|
:obj:`jnp.ndarray` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -54,7 +53,7 @@ class FlaxLogitsProcessor(ABC):
|
|||||||
"""Abstract base class for all logit processors that can be applied during generation."""
|
"""Abstract base class for all logit processors that can be applied during generation."""
|
||||||
|
|
||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
|
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
|
||||||
"""Flax method for processing logits."""
|
"""Flax method for processing logits."""
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||||
@@ -65,7 +64,7 @@ class FlaxLogitsWarper(ABC):
|
|||||||
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
|
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
|
||||||
|
|
||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
|
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
|
||||||
"""Flax method for warping logits."""
|
"""Flax method for warping logits."""
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||||
@@ -81,9 +80,7 @@ class FlaxLogitsProcessorList(list):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(
|
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int, **kwargs) -> jnp.ndarray:
|
||||||
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int, **kwargs
|
|
||||||
) -> jax_xla.DeviceArray:
|
|
||||||
for processor in self:
|
for processor in self:
|
||||||
function_args = inspect.signature(processor.__call__).parameters
|
function_args = inspect.signature(processor.__call__).parameters
|
||||||
if len(function_args) > 3:
|
if len(function_args) > 3:
|
||||||
@@ -111,9 +108,7 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
|
|||||||
|
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
|
|
||||||
def __call__(
|
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
||||||
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
|
|
||||||
) -> jax_xla.DeviceArray:
|
|
||||||
scores = scores / self.temperature
|
scores = scores / self.temperature
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
@@ -141,9 +136,7 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
|
|||||||
self.filter_value = filter_value
|
self.filter_value = filter_value
|
||||||
self.min_tokens_to_keep = min_tokens_to_keep
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
|
||||||
def __call__(
|
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
||||||
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
|
|
||||||
) -> jax_xla.DeviceArray:
|
|
||||||
topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
|
topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
|
||||||
|
|
||||||
mask_scores = jnp.full_like(scores, self.filter_value)
|
mask_scores = jnp.full_like(scores, self.filter_value)
|
||||||
@@ -183,9 +176,7 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
|
|||||||
self.filter_value = filter_value
|
self.filter_value = filter_value
|
||||||
self.min_tokens_to_keep = min_tokens_to_keep
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
|
||||||
def __call__(
|
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
||||||
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
|
|
||||||
) -> jax_xla.DeviceArray:
|
|
||||||
batch_size, vocab_size = scores.shape
|
batch_size, vocab_size = scores.shape
|
||||||
next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
|
next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
|
||||||
|
|
||||||
@@ -212,9 +203,7 @@ class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
|
|||||||
def __init__(self, bos_token_id: int):
|
def __init__(self, bos_token_id: int):
|
||||||
self.bos_token_id = bos_token_id
|
self.bos_token_id = bos_token_id
|
||||||
|
|
||||||
def __call__(
|
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
||||||
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
|
|
||||||
) -> jax_xla.DeviceArray:
|
|
||||||
new_scores = jnp.full(scores.shape, -float("inf"))
|
new_scores = jnp.full(scores.shape, -float("inf"))
|
||||||
|
|
||||||
apply_penalty = 1 - jnp.bool_(cur_len - 1)
|
apply_penalty = 1 - jnp.bool_(cur_len - 1)
|
||||||
@@ -242,9 +231,7 @@ class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
|
|||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
def __call__(
|
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
||||||
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
|
|
||||||
) -> jax_xla.DeviceArray:
|
|
||||||
new_scores = jnp.full(scores.shape, -float("inf"))
|
new_scores = jnp.full(scores.shape, -float("inf"))
|
||||||
|
|
||||||
apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
|
apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
|
||||||
@@ -277,9 +264,7 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
|
|||||||
self.min_length = min_length
|
self.min_length = min_length
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
def __call__(
|
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
||||||
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
|
|
||||||
) -> jax_xla.DeviceArray:
|
|
||||||
|
|
||||||
# create boolean flag to decide if min length penalty should be applied
|
# create boolean flag to decide if min length penalty should be applied
|
||||||
apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
|
apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ import numpy as np
|
|||||||
import flax
|
import flax
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jaxlib.xla_extension as jax_xla
|
|
||||||
from jax import lax
|
from jax import lax
|
||||||
|
|
||||||
from .file_utils import ModelOutput
|
from .file_utils import ModelOutput
|
||||||
@@ -49,11 +48,11 @@ class FlaxGreedySearchOutput(ModelOutput):
|
|||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
|
sequences (:obj:`jnp.ndarray` of shape :obj:`(batch_size, max_length)`):
|
||||||
The generated sequences.
|
The generated sequences.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: jax_xla.DeviceArray = None
|
sequences: jnp.ndarray = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -63,11 +62,11 @@ class FlaxSampleOutput(ModelOutput):
|
|||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
|
sequences (:obj:`jnp.ndarray` of shape :obj:`(batch_size, max_length)`):
|
||||||
The generated sequences.
|
The generated sequences.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: jax_xla.DeviceArray = None
|
sequences: jnp.ndarray = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -77,44 +76,44 @@ class FlaxBeamSearchOutput(ModelOutput):
|
|||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
|
sequences (:obj:`jnp.ndarray` of shape :obj:`(batch_size, max_length)`):
|
||||||
The generated sequences.
|
The generated sequences.
|
||||||
scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size,)`):
|
scores (:obj:`jnp.ndarray` of shape :obj:`(batch_size,)`):
|
||||||
The scores (log probabilities) of the generated sequences.
|
The scores (log probabilities) of the generated sequences.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: jax_xla.DeviceArray = None
|
sequences: jnp.ndarray = None
|
||||||
scores: jax_xla.DeviceArray = None
|
scores: jnp.ndarray = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
class GreedyState:
|
class GreedyState:
|
||||||
cur_len: jax_xla.DeviceArray
|
cur_len: jnp.ndarray
|
||||||
sequences: jax_xla.DeviceArray
|
sequences: jnp.ndarray
|
||||||
running_token: jax_xla.DeviceArray
|
running_token: jnp.ndarray
|
||||||
is_sent_finished: jax_xla.DeviceArray
|
is_sent_finished: jnp.ndarray
|
||||||
model_kwargs: Dict[str, jax_xla.DeviceArray]
|
model_kwargs: Dict[str, jnp.ndarray]
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
class SampleState:
|
class SampleState:
|
||||||
cur_len: jax_xla.DeviceArray
|
cur_len: jnp.ndarray
|
||||||
sequences: jax_xla.DeviceArray
|
sequences: jnp.ndarray
|
||||||
running_token: jax_xla.DeviceArray
|
running_token: jnp.ndarray
|
||||||
is_sent_finished: jax_xla.DeviceArray
|
is_sent_finished: jnp.ndarray
|
||||||
prng_key: jax_xla.DeviceArray
|
prng_key: jnp.ndarray
|
||||||
model_kwargs: Dict[str, jax_xla.DeviceArray]
|
model_kwargs: Dict[str, jnp.ndarray]
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
class BeamSearchState:
|
class BeamSearchState:
|
||||||
cur_len: jax_xla.DeviceArray
|
cur_len: jnp.ndarray
|
||||||
running_sequences: jax_xla.DeviceArray
|
running_sequences: jnp.ndarray
|
||||||
running_scores: jax_xla.DeviceArray
|
running_scores: jnp.ndarray
|
||||||
sequences: jax_xla.DeviceArray
|
sequences: jnp.ndarray
|
||||||
scores: jax_xla.DeviceArray
|
scores: jnp.ndarray
|
||||||
is_sent_finished: jax_xla.DeviceArray
|
is_sent_finished: jnp.ndarray
|
||||||
model_kwargs: Dict[str, jax_xla.DeviceArray]
|
model_kwargs: Dict[str, jnp.ndarray]
|
||||||
|
|
||||||
|
|
||||||
class FlaxGenerationMixin:
|
class FlaxGenerationMixin:
|
||||||
@@ -156,14 +155,14 @@ class FlaxGenerationMixin:
|
|||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
input_ids: jax_xla.DeviceArray,
|
input_ids: jnp.ndarray,
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
bos_token_id: Optional[int] = None,
|
bos_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[int] = None,
|
||||||
decoder_start_token_id: Optional[int] = None,
|
decoder_start_token_id: Optional[int] = None,
|
||||||
do_sample: Optional[bool] = None,
|
do_sample: Optional[bool] = None,
|
||||||
prng_key: Optional[jax_xla.DeviceArray] = None,
|
prng_key: Optional[jnp.ndarray] = None,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
@@ -175,7 +174,7 @@ class FlaxGenerationMixin:
|
|||||||
length_penalty: Optional[float] = None,
|
length_penalty: Optional[float] = None,
|
||||||
early_stopping: Optional[bool] = None,
|
early_stopping: Optional[bool] = None,
|
||||||
trace: bool = True,
|
trace: bool = True,
|
||||||
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
params: Optional[Dict[str, jnp.ndarray]] = None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -191,7 +190,7 @@ class FlaxGenerationMixin:
|
|||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
|
|
||||||
input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
The sequence used as a prompt for the generation.
|
The sequence used as a prompt for the generation.
|
||||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||||
The maximum length of the sequence to be generated.
|
The maximum length of the sequence to be generated.
|
||||||
@@ -217,7 +216,7 @@ class FlaxGenerationMixin:
|
|||||||
trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
|
Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
|
||||||
a considerably slower runtime.
|
a considerably slower runtime.
|
||||||
params (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`):
|
params (:obj:`Dict[str, jnp.ndarray]`, `optional`):
|
||||||
Optionally the model parameters can be passed. Can be useful for parallelized generation.
|
Optionally the model parameters can be passed. Can be useful for parallelized generation.
|
||||||
model_kwargs:
|
model_kwargs:
|
||||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
|
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
|
||||||
@@ -395,8 +394,8 @@ class FlaxGenerationMixin:
|
|||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[int] = None,
|
||||||
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
||||||
trace: bool = True,
|
trace: bool = True,
|
||||||
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
params: Optional[Dict[str, jnp.ndarray]] = None,
|
||||||
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
||||||
):
|
):
|
||||||
# init values
|
# init values
|
||||||
max_length = max_length if max_length is not None else self.config.max_length
|
max_length = max_length if max_length is not None else self.config.max_length
|
||||||
@@ -479,12 +478,12 @@ class FlaxGenerationMixin:
|
|||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[int] = None,
|
||||||
prng_key: Optional[jax_xla.DeviceArray] = None,
|
prng_key: Optional[jnp.ndarray] = None,
|
||||||
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
||||||
logits_warper: Optional[FlaxLogitsProcessorList] = None,
|
logits_warper: Optional[FlaxLogitsProcessorList] = None,
|
||||||
trace: bool = True,
|
trace: bool = True,
|
||||||
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
params: Optional[Dict[str, jnp.ndarray]] = None,
|
||||||
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
||||||
):
|
):
|
||||||
# init values
|
# init values
|
||||||
max_length = max_length if max_length is not None else self.config.max_length
|
max_length = max_length if max_length is not None else self.config.max_length
|
||||||
@@ -580,8 +579,8 @@ class FlaxGenerationMixin:
|
|||||||
early_stopping: Optional[bool] = None,
|
early_stopping: Optional[bool] = None,
|
||||||
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
||||||
trace: bool = True,
|
trace: bool = True,
|
||||||
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
params: Optional[Dict[str, jnp.ndarray]] = None,
|
||||||
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
This beam search function is heavily inspired by Flax's official example:
|
This beam search function is heavily inspired by Flax's official example:
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
import flax
|
import flax
|
||||||
import jaxlib.xla_extension as jax_xla
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from .file_utils import ModelOutput
|
from .file_utils import ModelOutput
|
||||||
|
|
||||||
@@ -25,24 +25,24 @@ class FlaxBaseModelOutput(ModelOutput):
|
|||||||
Base class for model's outputs, with potential hidden states and attentions.
|
Base class for model's outputs, with potential hidden states and attentions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
Sequence of hidden-states at the output of the last layer of the model.
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
last_hidden_state: jax_xla.DeviceArray = None
|
last_hidden_state: jnp.ndarray = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -51,28 +51,28 @@ class FlaxBaseModelOutputWithPast(ModelOutput):
|
|||||||
Base class for model's outputs, with potential hidden states and attentions.
|
Base class for model's outputs, with potential hidden states and attentions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
Sequence of hidden-states at the output of the last layer of the model.
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
past_key_values (:obj:`Dict[str, jax_xla.DeviceArray]`):
|
past_key_values (:obj:`Dict[str, jnp.ndarray]`):
|
||||||
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
||||||
auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
|
auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
|
||||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
last_hidden_state: jax_xla.DeviceArray = None
|
last_hidden_state: jnp.ndarray = None
|
||||||
past_key_values: Optional[Dict[str, jax_xla.DeviceArray]] = None
|
past_key_values: Optional[Dict[str, jnp.ndarray]] = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -81,29 +81,29 @@ class FlaxBaseModelOutputWithPooling(ModelOutput):
|
|||||||
Base class for model's outputs that also contains a pooling of the last hidden states.
|
Base class for model's outputs that also contains a pooling of the last hidden states.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
Sequence of hidden-states at the output of the last layer of the model.
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
pooler_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`):
|
pooler_output (:obj:`jnp.ndarray` of shape :obj:`(batch_size, hidden_size)`):
|
||||||
Last layer hidden-state of the first token of the sequence (classification token) further processed by a
|
Last layer hidden-state of the first token of the sequence (classification token) further processed by a
|
||||||
Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
|
Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
|
||||||
prediction (classification) objective during pretraining.
|
prediction (classification) objective during pretraining.
|
||||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
last_hidden_state: jax_xla.DeviceArray = None
|
last_hidden_state: jnp.ndarray = None
|
||||||
pooler_output: jax_xla.DeviceArray = None
|
pooler_output: jnp.ndarray = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -112,44 +112,44 @@ class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
|
|||||||
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
|
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
Sequence of hidden-states at the output of the last layer of the model.
|
Sequence of hidden-states at the output of the last layer of the model.
|
||||||
|
|
||||||
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
|
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
|
||||||
1, hidden_size)` is output.
|
1, hidden_size)` is output.
|
||||||
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
|
Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of
|
||||||
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally if
|
shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally if
|
||||||
``config.is_encoder_decoder=True`` 2 additional tensors of shape :obj:`(batch_size, num_heads,
|
``config.is_encoder_decoder=True`` 2 additional tensors of shape :obj:`(batch_size, num_heads,
|
||||||
encoder_sequence_length, embed_size_per_head)`.
|
encoder_sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
|
||||||
``config.is_encoder_decoder=True`` in the cross-attention blocks) that can be used (see
|
``config.is_encoder_decoder=True`` in the cross-attention blocks) that can be used (see
|
||||||
:obj:`past_key_values` input) to speed up sequential decoding.
|
:obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
|
cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
weighted average in the cross-attention heads.
|
weighted average in the cross-attention heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
last_hidden_state: jax_xla.DeviceArray = None
|
last_hidden_state: jnp.ndarray = None
|
||||||
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
|
past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
cross_attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -159,58 +159,58 @@ class FlaxSeq2SeqModelOutput(ModelOutput):
|
|||||||
decoding.
|
decoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
||||||
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
Sequence of hidden-states at the output of the last layer of the decoder of the model.
|
||||||
|
|
||||||
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
|
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
|
||||||
1, hidden_size)` is output.
|
1, hidden_size)` is output.
|
||||||
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
|
Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of
|
||||||
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional
|
shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
|
||||||
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
|
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
||||||
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
||||||
self-attention heads.
|
self-attention heads.
|
||||||
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
weighted average in the cross-attention heads.
|
weighted average in the cross-attention heads.
|
||||||
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
||||||
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
||||||
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
||||||
self-attention heads.
|
self-attention heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
last_hidden_state: jax_xla.DeviceArray = None
|
last_hidden_state: jnp.ndarray = None
|
||||||
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
|
past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
cross_attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
|
encoder_last_hidden_state: Optional[jnp.ndarray] = None
|
||||||
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -219,39 +219,39 @@ class FlaxCausalLMOutputWithCrossAttentions(ModelOutput):
|
|||||||
Base class for causal language model (or autoregressive) outputs.
|
Base class for causal language model (or autoregressive) outputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Cross-attention weights after the attention softmax, used to compute the weighted average in the
|
Cross-attention weights after the attention softmax, used to compute the weighted average in the
|
||||||
cross-attention heads.
|
cross-attention heads.
|
||||||
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` tuples of length :obj:`config.n_layers`, with each tuple containing the
|
Tuple of :obj:`jnp.ndarray` tuples of length :obj:`config.n_layers`, with each tuple containing the cached
|
||||||
cached key, value states of the self-attention and the cross-attention layers if model is used in
|
key, value states of the self-attention and the cross-attention layers if model is used in encoder-decoder
|
||||||
encoder-decoder setting. Only relevant if ``config.is_decoder = True``.
|
setting. Only relevant if ``config.is_decoder = True``.
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
||||||
:obj:`past_key_values` input) to speed up sequential decoding.
|
:obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logits: jax_xla.DeviceArray = None
|
logits: jnp.ndarray = None
|
||||||
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
|
past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
cross_attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -260,24 +260,24 @@ class FlaxMaskedLMOutput(ModelOutput):
|
|||||||
Base class for masked language models outputs.
|
Base class for masked language models outputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logits: jax_xla.DeviceArray = None
|
logits: jnp.ndarray = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
FlaxCausalLMOutput = FlaxMaskedLMOutput
|
FlaxCausalLMOutput = FlaxMaskedLMOutput
|
||||||
@@ -289,55 +289,55 @@ class FlaxSeq2SeqLMOutput(ModelOutput):
|
|||||||
Base class for sequence-to-sequence language models outputs.
|
Base class for sequence-to-sequence language models outputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
|
Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of
|
||||||
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional
|
shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
|
||||||
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
|
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
||||||
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
||||||
self-attention heads.
|
self-attention heads.
|
||||||
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
weighted average in the cross-attention heads.
|
weighted average in the cross-attention heads.
|
||||||
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
||||||
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
||||||
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
||||||
self-attention heads.
|
self-attention heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logits: jax_xla.DeviceArray = None
|
logits: jnp.ndarray = None
|
||||||
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
|
past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
cross_attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
|
encoder_last_hidden_state: Optional[jnp.ndarray] = None
|
||||||
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -346,25 +346,25 @@ class FlaxNextSentencePredictorOutput(ModelOutput):
|
|||||||
Base class for outputs of models predicting if two sentences are consecutive or not.
|
Base class for outputs of models predicting if two sentences are consecutive or not.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`):
|
logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, 2)`):
|
||||||
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
||||||
before SoftMax).
|
before SoftMax).
|
||||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logits: jax_xla.DeviceArray = None
|
logits: jnp.ndarray = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -373,24 +373,24 @@ class FlaxSequenceClassifierOutput(ModelOutput):
|
|||||||
Base class for outputs of sentence classification models.
|
Base class for outputs of sentence classification models.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`):
|
logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, config.num_labels)`):
|
||||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logits: jax_xla.DeviceArray = None
|
logits: jnp.ndarray = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -399,55 +399,55 @@ class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput):
|
|||||||
Base class for outputs of sequence-to-sequence sentence classification models.
|
Base class for outputs of sequence-to-sequence sentence classification models.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`):
|
logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, config.num_labels)`):
|
||||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||||
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
|
Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of
|
||||||
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional
|
shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
|
||||||
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
|
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
||||||
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
||||||
self-attention heads.
|
self-attention heads.
|
||||||
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
weighted average in the cross-attention heads.
|
weighted average in the cross-attention heads.
|
||||||
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
||||||
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
||||||
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
||||||
self-attention heads.
|
self-attention heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logits: jax_xla.DeviceArray = None
|
logits: jnp.ndarray = None
|
||||||
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
|
past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
cross_attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
|
encoder_last_hidden_state: Optional[jnp.ndarray] = None
|
||||||
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -456,26 +456,26 @@ class FlaxMultipleChoiceModelOutput(ModelOutput):
|
|||||||
Base class for outputs of multiple choice models.
|
Base class for outputs of multiple choice models.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, num_choices)`):
|
logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, num_choices)`):
|
||||||
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
|
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
|
||||||
|
|
||||||
Classification scores (before SoftMax).
|
Classification scores (before SoftMax).
|
||||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logits: jax_xla.DeviceArray = None
|
logits: jnp.ndarray = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -484,24 +484,24 @@ class FlaxTokenClassifierOutput(ModelOutput):
|
|||||||
Base class for outputs of token classification models.
|
Base class for outputs of token classification models.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
|
logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
|
||||||
Classification scores (before SoftMax).
|
Classification scores (before SoftMax).
|
||||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logits: jax_xla.DeviceArray = None
|
logits: jnp.ndarray = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -510,27 +510,27 @@ class FlaxQuestionAnsweringModelOutput(ModelOutput):
|
|||||||
Base class for outputs of question answering models.
|
Base class for outputs of question answering models.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
|
start_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
Span-start scores (before SoftMax).
|
Span-start scores (before SoftMax).
|
||||||
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
|
end_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
Span-end scores (before SoftMax).
|
Span-end scores (before SoftMax).
|
||||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
start_logits: jax_xla.DeviceArray = None
|
start_logits: jnp.ndarray = None
|
||||||
end_logits: jax_xla.DeviceArray = None
|
end_logits: jnp.ndarray = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -539,55 +539,55 @@ class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
|||||||
Base class for outputs of sequence-to-sequence question answering models.
|
Base class for outputs of sequence-to-sequence question answering models.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
|
start_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
Span-start scores (before SoftMax).
|
Span-start scores (before SoftMax).
|
||||||
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
|
end_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
Span-end scores (before SoftMax).
|
Span-end scores (before SoftMax).
|
||||||
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
past_key_values (:obj:`tuple(tuple(jnp.ndarray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
|
Tuple of :obj:`tuple(jnp.ndarray)` of length :obj:`config.n_layers`, with each tuple having 2 tensors of
|
||||||
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional
|
shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
|
||||||
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
|
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
decoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
|
||||||
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
decoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
|
||||||
self-attention heads.
|
self-attention heads.
|
||||||
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
cross_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
|
||||||
weighted average in the cross-attention heads.
|
weighted average in the cross-attention heads.
|
||||||
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
encoder_last_hidden_state (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
Sequence of hidden-states at the output of the last layer of the encoder of the model.
|
||||||
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
encoder_hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
|
||||||
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
encoder_attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
|
||||||
self-attention heads.
|
self-attention heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
start_logits: jax_xla.DeviceArray = None
|
start_logits: jnp.ndarray = None
|
||||||
end_logits: jax_xla.DeviceArray = None
|
end_logits: jnp.ndarray = None
|
||||||
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
|
past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
|
||||||
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
cross_attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
|
encoder_last_hidden_state: Optional[jnp.ndarray] = None
|
||||||
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import flax
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jaxlib.xla_extension as jax_xla
|
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
from jax import lax
|
from jax import lax
|
||||||
@@ -61,28 +60,28 @@ class FlaxBertForPreTrainingOutput(ModelOutput):
|
|||||||
Output type of :class:`~transformers.BertForPreTraining`.
|
Output type of :class:`~transformers.BertForPreTraining`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
prediction_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`):
|
seq_relationship_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, 2)`):
|
||||||
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
||||||
before SoftMax).
|
before SoftMax).
|
||||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
prediction_logits: jax_xla.DeviceArray = None
|
prediction_logits: jnp.ndarray = None
|
||||||
seq_relationship_logits: jax_xla.DeviceArray = None
|
seq_relationship_logits: jnp.ndarray = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
BERT_START_DOCSTRING = r"""
|
BERT_START_DOCSTRING = r"""
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import flax
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jaxlib.xla_extension as jax_xla
|
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
from jax import lax
|
from jax import lax
|
||||||
@@ -59,28 +58,28 @@ class FlaxBigBirdForPreTrainingOutput(ModelOutput):
|
|||||||
Output type of :class:`~transformers.BigBirdForPreTraining`.
|
Output type of :class:`~transformers.BigBirdForPreTraining`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
prediction_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`):
|
seq_relationship_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, 2)`):
|
||||||
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
||||||
before SoftMax).
|
before SoftMax).
|
||||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
prediction_logits: jax_xla.DeviceArray = None
|
prediction_logits: jnp.ndarray = None
|
||||||
seq_relationship_logits: jax_xla.DeviceArray = None
|
seq_relationship_logits: jnp.ndarray = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -89,30 +88,30 @@ class FlaxBigBirdForQuestionAnsweringModelOutput(ModelOutput):
|
|||||||
Base class for outputs of question answering models.
|
Base class for outputs of question answering models.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
|
start_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
Span-start scores (before SoftMax).
|
Span-start scores (before SoftMax).
|
||||||
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
|
end_logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
Span-end scores (before SoftMax).
|
Span-end scores (before SoftMax).
|
||||||
pooled_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`):
|
pooled_output (:obj:`jnp.ndarray` of shape :obj:`(batch_size, hidden_size)`):
|
||||||
pooled_output returned by FlaxBigBirdModel.
|
pooled_output returned by FlaxBigBirdModel.
|
||||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
start_logits: jax_xla.DeviceArray = None
|
start_logits: jnp.ndarray = None
|
||||||
end_logits: jax_xla.DeviceArray = None
|
end_logits: jnp.ndarray = None
|
||||||
pooled_output: jax_xla.DeviceArray = None
|
pooled_output: jnp.ndarray = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
BIG_BIRD_START_DOCSTRING = r"""
|
BIG_BIRD_START_DOCSTRING = r"""
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import flax
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jaxlib.xla_extension as jax_xla
|
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict
|
||||||
from flax.linen import combine_masks, make_causal_mask
|
from flax.linen import combine_masks, make_causal_mask
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
@@ -156,16 +155,16 @@ CLIP_INPUTS_DOCSTRING = r"""
|
|||||||
class FlaxCLIPOutput(ModelOutput):
|
class FlaxCLIPOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
logits_per_image:(:obj:`jax_xla.DeviceArray` of shape :obj:`(image_batch_size, text_batch_size)`):
|
logits_per_image:(:obj:`jnp.ndarray` of shape :obj:`(image_batch_size, text_batch_size)`):
|
||||||
The scaled dot product scores between :obj:`image_embeds` and :obj:`text_embeds`. This represents the
|
The scaled dot product scores between :obj:`image_embeds` and :obj:`text_embeds`. This represents the
|
||||||
image-text similarity scores.
|
image-text similarity scores.
|
||||||
logits_per_text:(:obj:`jax_xla.DeviceArray` of shape :obj:`(text_batch_size, image_batch_size)`):
|
logits_per_text:(:obj:`jnp.ndarray` of shape :obj:`(text_batch_size, image_batch_size)`):
|
||||||
The scaled dot product scores between :obj:`text_embeds` and :obj:`image_embeds`. This represents the
|
The scaled dot product scores between :obj:`text_embeds` and :obj:`image_embeds`. This represents the
|
||||||
text-image similarity scores.
|
text-image similarity scores.
|
||||||
text_embeds(:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`):
|
text_embeds(:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`):
|
||||||
The text embeddings obtained by applying the projection layer to the pooled output of
|
The text embeddings obtained by applying the projection layer to the pooled output of
|
||||||
:class:`~transformers.FlaxCLIPTextModel`.
|
:class:`~transformers.FlaxCLIPTextModel`.
|
||||||
image_embeds(:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`):
|
image_embeds(:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`):
|
||||||
The image embeddings obtained by applying the projection layer to the pooled output of
|
The image embeddings obtained by applying the projection layer to the pooled output of
|
||||||
:class:`~transformers.FlaxCLIPVisionModel`.
|
:class:`~transformers.FlaxCLIPVisionModel`.
|
||||||
text_model_output(:obj:`FlaxBaseModelOutputWithPooling`):
|
text_model_output(:obj:`FlaxBaseModelOutputWithPooling`):
|
||||||
@@ -174,10 +173,10 @@ class FlaxCLIPOutput(ModelOutput):
|
|||||||
The output of the :class:`~transformers.FlaxCLIPVisionModel`.
|
The output of the :class:`~transformers.FlaxCLIPVisionModel`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logits_per_image: jax_xla.DeviceArray = None
|
logits_per_image: jnp.ndarray = None
|
||||||
logits_per_text: jax_xla.DeviceArray = None
|
logits_per_text: jnp.ndarray = None
|
||||||
text_embeds: jax_xla.DeviceArray = None
|
text_embeds: jnp.ndarray = None
|
||||||
image_embeds: jax_xla.DeviceArray = None
|
image_embeds: jnp.ndarray = None
|
||||||
text_model_output: FlaxBaseModelOutputWithPooling = None
|
text_model_output: FlaxBaseModelOutputWithPooling = None
|
||||||
vision_model_output: FlaxBaseModelOutputWithPooling = None
|
vision_model_output: FlaxBaseModelOutputWithPooling = None
|
||||||
|
|
||||||
@@ -801,8 +800,8 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
`What are input IDs? <../glossary.html#input-ids>`__
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The text embeddings
|
text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`): The text embeddings obtained by
|
||||||
obtained by applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPTextModel`.
|
applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPTextModel`.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
@@ -855,9 +854,8 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
:meth:`transformers.CLIPFeatureExtractor.__call__` for details.
|
:meth:`transformers.CLIPFeatureExtractor.__call__` for details.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The image embeddings
|
image_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`): The image embeddings obtained
|
||||||
obtained by applying the projection layer to the pooled output of
|
by applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPVisionModel`.
|
||||||
:class:`~transformers.FlaxCLIPVisionModel`
|
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import flax
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jaxlib.xla_extension as jax_xla
|
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
from jax import lax
|
from jax import lax
|
||||||
@@ -60,24 +59,24 @@ class FlaxElectraForPreTrainingOutput(ModelOutput):
|
|||||||
Output type of :class:`~transformers.ElectraForPreTraining`.
|
Output type of :class:`~transformers.ElectraForPreTraining`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
logits (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
hidden_states (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
|
Tuple of :obj:`jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of
|
||||||
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||||
|
|
||||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
attentions (:obj:`tuple(jnp.ndarray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
Tuple of :obj:`jnp.ndarray` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
|
||||||
sequence_length, sequence_length)`.
|
sequence_length)`.
|
||||||
|
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logits: jax_xla.DeviceArray = None
|
logits: jnp.ndarray = None
|
||||||
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
|
hidden_states: Optional[Tuple[jnp.ndarray]] = None
|
||||||
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
|
attentions: Optional[Tuple[jnp.ndarray]] = None
|
||||||
|
|
||||||
|
|
||||||
ELECTRA_START_DOCSTRING = r"""
|
ELECTRA_START_DOCSTRING = r"""
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ if is_flax_available():
|
|||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jaxlib.xla_extension as jax_xla
|
|
||||||
from flax.core.frozen_dict import unfreeze
|
from flax.core.frozen_dict import unfreeze
|
||||||
from flax.traverse_util import flatten_dict
|
from flax.traverse_util import flatten_dict
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -127,7 +126,7 @@ class FlaxModelTesterMixin:
|
|||||||
if "ForMultipleChoice" in model_class.__name__:
|
if "ForMultipleChoice" in model_class.__name__:
|
||||||
inputs_dict = {
|
inputs_dict = {
|
||||||
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
|
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
|
||||||
if isinstance(v, (jax_xla.DeviceArray, np.ndarray))
|
if isinstance(v, (jnp.ndarray, np.ndarray))
|
||||||
else v
|
else v
|
||||||
for k, v in inputs_dict.items()
|
for k, v in inputs_dict.items()
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user