Add TFViTModel (#13778)
* Start the work for TFViTModel * Convert to TF code - need to check in the follow up commits * Clean up model code * Expose TFViTModel * make style * make quality * Add test * make style & quality * Fix some imports * fix wrong usage - *kwargs => ** kwargs * Fix Conv2D weight loading (PT->TF) issue * Add tests for images with different sizes + fix model * Fix some common tests for TFViTModel * Use inputs instead of input_ids in test_compile_tf_model * Add a comment about transpose and Conv2D in convert_tf_weight_name_to_pt_weight_name * Avoid transpose in TFViT call * Fix Conv2D issue in load_tf2_weights_in_pytorch_model * Use tf.keras.layers.Conv2D instead of tf.nn.conv2d * Using simpler heuristic to detect Conv2D layer * Change convert_tf_weight_name_to_pt_weight_name to return TransposeType * Check tf_weight_shape is not None before using it * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix missing comma * fix input dtype Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -503,7 +503,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| ViT | ❌ | ❌ | ✅ | ❌ | ✅ |
|
||||
| ViT | ❌ | ❌ | ✅ | ✅ | ✅ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
|
||||
@@ -223,6 +223,13 @@ TFAutoModelForCausalLM
|
||||
:members:
|
||||
|
||||
|
||||
TFAutoModelForImageClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFAutoModelForImageClassification
|
||||
:members:
|
||||
|
||||
|
||||
TFAutoModelForMaskedLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -120,6 +120,20 @@ ViTForImageClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
TFViTModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFViTModel
|
||||
:members: call
|
||||
|
||||
|
||||
TFViTForImageClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFViTForImageClassification
|
||||
:members: call
|
||||
|
||||
|
||||
FlaxVitModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
Reference in New Issue
Block a user