Add FlaxCLIP (#11883)

* add flax CLIP

* default input_shape

* add tests

* fix test

* fix name

* fix docs

* fix shapes

* attend at least 1 token

* flax conv to torch conv

* return floats

* fix equivalence tests

* fix import

* return attention_weights and update tests

* fix dosctrings

* address patricks comments

* input_shape arg

* add tests for get_image_features and get_text_features methods

* fix tests
This commit is contained in:
Suraj Patil
2021-06-01 09:44:31 +05:30
committed by GitHub
parent cfca638acb
commit ad25fd62bd
13 changed files with 1737 additions and 6 deletions

View File

@@ -152,3 +152,24 @@ CLIPVisionModel
.. autoclass:: transformers.CLIPVisionModel
:members: forward
FlaxCLIPModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxCLIPModel
:members: __call__, get_text_features, get_image_features
FlaxCLIPTextModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxCLIPTextModel
:members: __call__
FlaxCLIPVisionModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxCLIPVisionModel
:members: __call__