[Flax] Align jax flax device name (#12987)
* [Flax] Align device name in docs * make style * fix import error
This commit is contained in:
committed by
GitHub
parent
07df5578d9
commit
da9754a3a0
@@ -224,7 +224,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
|
||||
Returns:
|
||||
text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The text embeddings
|
||||
text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The text embeddings
|
||||
obtained by applying the projection layer to the pooled output of text model.
|
||||
"""
|
||||
if position_ids is None:
|
||||
@@ -273,7 +273,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
||||
:meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
|
||||
|
||||
Returns:
|
||||
image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The image embeddings
|
||||
image_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The image embeddings
|
||||
obtained by applying the projection layer to the pooled output of vision model.
|
||||
"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user