Add TFViTModel (#13778)

* Start the work for TFViTModel

* Convert to TF code - need to check in the follow up commits

* Clean up model code

* Expose TFViTModel

* make style

* make quality

* Add test

* make style & quality

* Fix some imports

* fix wrong usage - *kwargs => ** kwargs

* Fix Conv2D weight loading (PT->TF) issue

* Add tests for images with different sizes + fix model

* Fix some common tests for TFViTModel

* Use inputs instead of input_ids in test_compile_tf_model

* Add a comment about transpose and Conv2D in convert_tf_weight_name_to_pt_weight_name

* Avoid transpose in TFViT call

* Fix Conv2D issue in load_tf2_weights_in_pytorch_model

* Use tf.keras.layers.Conv2D instead of tf.nn.conv2d

* Using simpler heuristic to detect Conv2D layer

* Change convert_tf_weight_name_to_pt_weight_name to return TransposeType

* Check tf_weight_shape is not None before using it

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix missing comma

* fix input dtype

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2021-11-09 13:54:37 +01:00
committed by GitHub
parent 6326aa4bf0
commit be4a6c64dc
13 changed files with 1420 additions and 19 deletions

View File

@@ -223,6 +223,13 @@ TFAutoModelForCausalLM
:members:
TFAutoModelForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFAutoModelForImageClassification
:members:
TFAutoModelForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@@ -120,6 +120,20 @@ ViTForImageClassification
:members: forward
TFViTModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFViTModel
:members: call
TFViTForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFViTForImageClassification
:members: call
FlaxVitModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~