[Flax] Align jax flax device name (#12987)

* [Flax] Align device name in docs

* make style

* fix import error
This commit is contained in:
Patrick von Platen
2021-08-04 16:00:09 +02:00
committed by GitHub
parent 07df5578d9
commit da9754a3a0
9 changed files with 365 additions and 387 deletions

View File

@@ -224,7 +224,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
`What are input IDs? <../glossary.html#input-ids>`__
Returns:
text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The text embeddings
text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The text embeddings
obtained by applying the projection layer to the pooled output of text model.
"""
if position_ids is None:
@@ -273,7 +273,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
:meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
Returns:
image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The image embeddings
image_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The image embeddings
obtained by applying the projection layer to the pooled output of vision model.
"""