Flax VisionTransformer (#11951)

* adding vit for flax

* added test for Flax-vit and some bug-fixes

* overrided methods where variable changes were necessary for flax_vit test

* added FlaxViTForImageClassification for test

* Update src/transformers/models/vit/modeling_flax_vit.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* made changes suggested in PR

* Adding jax-vit models for autoimport

* swapping num_channels and height,width dimension

* fixing the docstring for torch-like inputs for VIT

* add model to main init

* add docs

* doc, fix-copies

* docstrings

* small test fixes

* fix docs

* fix docstr

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* style

Co-authored-by: jayendra <jayendra@infocusp.in>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Jayendra
2021-06-10 21:17:13 +05:30
committed by GitHub
parent 0eaeae2e36
commit 9a9314f6d9
9 changed files with 903 additions and 3 deletions

View File

@@ -395,7 +395,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ViT | ❌ | ❌ | ✅ | ❌ | |
| ViT | ❌ | ❌ | ✅ | ❌ | |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+

View File

@@ -101,3 +101,18 @@ ViTForImageClassification
.. autoclass:: transformers.ViTForImageClassification
:members: forward
FlaxVitModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxViTModel
:members: __call__
FlaxViTForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxViTForImageClassification
:members: __call__