From be4a6c64dcd67e64cedb5b5eb220bc7a8b128121 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 9 Nov 2021 13:54:37 +0100 Subject: [PATCH] Add TFViTModel (#13778) * Start the work for TFViTModel * Convert to TF code - need to check in the follow up commits * Clean up model code * Expose TFViTModel * make style * make quality * Add test * make style & quality * Fix some imports * fix wrong usage - *kwargs => ** kwargs * Fix Conv2D weight loading (PT->TF) issue * Add tests for images with different sizes + fix model * Fix some common tests for TFViTModel * Use inputs instead of input_ids in test_compile_tf_model * Add a comment about transpose and Conv2D in convert_tf_weight_name_to_pt_weight_name * Avoid transpose in TFViT call * Fix Conv2D issue in load_tf2_weights_in_pytorch_model * Use tf.keras.layers.Conv2D instead of tf.nn.conv2d * Using simpler heuristic to detect Conv2D layer * Change convert_tf_weight_name_to_pt_weight_name to return TransposeType * Check tf_weight_shape is not None before using it * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix missing comma * fix input dtype Co-authored-by: ydshieh Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/index.rst | 2 +- docs/source/model_doc/auto.rst | 7 + docs/source/model_doc/vit.rst | 14 + src/transformers/__init__.py | 12 + src/transformers/modeling_tf_pytorch_utils.py | 45 +- src/transformers/models/auto/__init__.py | 4 + .../models/auto/modeling_tf_auto.py | 18 + src/transformers/models/vit/__init__.py | 11 +- .../models/vit/modeling_tf_vit.py | 859 ++++++++++++++++++ src/transformers/utils/dummy_tf_objects.py | 35 + tests/test_modeling_common.py | 4 + tests/test_modeling_tf_common.py | 39 +- tests/test_modeling_tf_vit.py | 389 ++++++++ 13 files changed, 1420 insertions(+), 19 deletions(-) create mode 100644 src/transformers/models/vit/modeling_tf_vit.py create mode 100644 tests/test_modeling_tf_vit.py diff --git a/docs/source/index.rst b/docs/source/index.rst index 357ae26f89..cea20ecf7f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -503,7 +503,7 @@ Flax), PyTorch, and/or TensorFlow. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ -| ViT | ❌ | ❌ | ✅ | ❌ | ✅ | +| ViT | ❌ | ❌ | ✅ | ✅ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ diff --git a/docs/source/model_doc/auto.rst b/docs/source/model_doc/auto.rst index fe61e6557d..ef28d74207 100644 --- a/docs/source/model_doc/auto.rst +++ b/docs/source/model_doc/auto.rst @@ -223,6 +223,13 @@ TFAutoModelForCausalLM :members: +TFAutoModelForImageClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFAutoModelForImageClassification + :members: + + TFAutoModelForMaskedLM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/vit.rst b/docs/source/model_doc/vit.rst index f06da21275..c52de0c946 100644 --- a/docs/source/model_doc/vit.rst +++ b/docs/source/model_doc/vit.rst @@ -120,6 +120,20 @@ ViTForImageClassification :members: forward +TFViTModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFViTModel + :members: call + + +TFViTForImageClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFViTForImageClassification + :members: call + + FlaxVitModel ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index acf93b5dd7..6af1df7469 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1396,6 +1396,7 @@ if is_tf_available(): _import_structure["models.auto"].extend( [ "TF_MODEL_FOR_CAUSAL_LM_MAPPING", + "TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_MASKED_LM_MAPPING", "TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", @@ -1408,6 +1409,7 @@ if is_tf_available(): "TF_MODEL_WITH_LM_HEAD_MAPPING", "TFAutoModel", "TFAutoModelForCausalLM", + "TFAutoModelForImageClassification", "TFAutoModelForMaskedLM", "TFAutoModelForMultipleChoice", "TFAutoModelForPreTraining", @@ -1734,6 +1736,13 @@ if is_tf_available(): "TFTransfoXLPreTrainedModel", ] ) + _import_structure["models.vit"].extend( + [ + "TFViTForImageClassification", + "TFViTModel", + "TFViTPreTrainedModel", + ] + ) _import_structure["models.wav2vec2"].extend( [ "TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -3133,6 +3142,7 @@ if TYPE_CHECKING: ) from .models.auto import ( TF_MODEL_FOR_CAUSAL_LM_MAPPING, + TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, @@ -3145,6 +3155,7 @@ if TYPE_CHECKING: TF_MODEL_WITH_LM_HEAD_MAPPING, TFAutoModel, TFAutoModelForCausalLM, + TFAutoModelForImageClassification, TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, TFAutoModelForPreTraining, @@ -3406,6 +3417,7 @@ if TYPE_CHECKING: TFTransfoXLModel, TFTransfoXLPreTrainedModel, ) + from .models.vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel from .models.wav2vec2 import ( TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, TFWav2Vec2ForCTC, diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index db3d1cf705..9ed2a0a1cd 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -21,13 +21,24 @@ import re import numpy +from .file_utils import ExplicitEnum from .utils import logging logger = logging.get_logger(__name__) -def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=""): +class TransposeType(ExplicitEnum): + """ + Possible ... + """ + + NO = "no" + SIMPLE = "simple" + CONV2D = "conv2d" + + +def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="", tf_weight_shape=None): """ Convert a TF 2.0 model variable name in a pytorch model weight name. @@ -39,8 +50,8 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="") return tuple with: - pytorch model weight name - - transpose: boolean indicating whether TF2.0 and PyTorch weights matrices are transposed with regards to each - other + - transpose: `TransposeType` member indicating whether and how TF2.0 and PyTorch weights matrices should be + transposed with regards to each other """ tf_name = tf_name.replace(":0", "") # device ids tf_name = re.sub( @@ -56,11 +67,17 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="") tf_name = tf_name[1:] # Remove level zero # When should we transpose the weights - transpose = bool( + if tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 4: + # A simple heuristic to detect conv layer using weight array shape + transpose = TransposeType.CONV2D + elif bool( tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"] or "emb_projs" in tf_name or "out_projs" in tf_name - ) + ): + transpose = TransposeType.SIMPLE + else: + transpose = TransposeType.NO # Convert standard TF2.0 names in PyTorch names if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma": @@ -165,7 +182,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a for symbolic_weight in symbolic_weights: sw_name = symbolic_weight.name name, transpose = convert_tf_weight_name_to_pt_weight_name( - sw_name, start_prefix_to_remove=start_prefix_to_remove + sw_name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=symbolic_weight.shape ) # Find associated numpy array in pytorch model state dict @@ -182,7 +199,12 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a array = pt_state_dict[name].numpy() - if transpose: + if transpose is TransposeType.CONV2D: + # Conv2D weight: + # PT: (num_out_channel, num_in_channel, kernel[0], kernel[1]) + # -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel) + array = numpy.transpose(array, axes=(2, 3, 1, 0)) + elif transpose is TransposeType.SIMPLE: array = numpy.transpose(array) if len(symbolic_weight.shape) < len(array.shape): @@ -326,7 +348,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F tf_weights_map = {} for tf_weight in tf_weights: pt_name, transpose = convert_tf_weight_name_to_pt_weight_name( - tf_weight.name, start_prefix_to_remove=start_prefix_to_remove + tf_weight.name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape ) tf_weights_map[pt_name] = (tf_weight.numpy(), transpose) @@ -350,7 +372,12 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F array, transpose = tf_weights_map[pt_weight_name] - if transpose: + if transpose is TransposeType.CONV2D: + # Conv2D weight: + # TF: (kernel[0], kernel[1], num_in_channel, num_out_channel) + # -> PT: (num_out_channel, num_in_channel, kernel[0], kernel[1]) + array = numpy.transpose(array, axes=(3, 2, 0, 1)) + elif transpose is TransposeType.SIMPLE: array = numpy.transpose(array) if len(pt_weight.shape) < len(array.shape): diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 23c2424172..cf1ca9137a 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -73,6 +73,7 @@ if is_torch_available(): if is_tf_available(): _import_structure["modeling_tf_auto"] = [ "TF_MODEL_FOR_CAUSAL_LM_MAPPING", + "TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_MASKED_LM_MAPPING", "TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", @@ -85,6 +86,7 @@ if is_tf_available(): "TF_MODEL_WITH_LM_HEAD_MAPPING", "TFAutoModel", "TFAutoModelForCausalLM", + "TFAutoModelForImageClassification", "TFAutoModelForMaskedLM", "TFAutoModelForMultipleChoice", "TFAutoModelForPreTraining", @@ -175,6 +177,7 @@ if TYPE_CHECKING: if is_tf_available(): from .modeling_tf_auto import ( TF_MODEL_FOR_CAUSAL_LM_MAPPING, + TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, @@ -187,6 +190,7 @@ if TYPE_CHECKING: TF_MODEL_WITH_LM_HEAD_MAPPING, TFAutoModel, TFAutoModelForCausalLM, + TFAutoModelForImageClassification, TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, TFAutoModelForPreTraining, diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index f7ff78e629..0ff9eed872 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -64,6 +64,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict( ("pegasus", "TFPegasusModel"), ("blenderbot", "TFBlenderbotModel"), ("blenderbot-small", "TFBlenderbotSmallModel"), + ("vit", "TFViTModel"), ("wav2vec2", "TFWav2Vec2Model"), ("hubert", "TFHubertModel"), ] @@ -144,6 +145,13 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ] ) +TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Image-classsification + ("vit", "TFViTForImageClassification"), + ] +) + TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ # Model for Masked LM mapping @@ -302,6 +310,9 @@ TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES) TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES) TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) +TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES +) TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES) TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES @@ -352,6 +363,13 @@ class TFAutoModelForCausalLM(_BaseAutoModelClass): TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling") +class TFAutoModelForImageClassification(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING + + +AutoModelForImageClassification = auto_class_update(TFAutoModelForImageClassification, head_doc="image classification") + + class TFAutoModelForMaskedLM(_BaseAutoModelClass): _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING diff --git a/src/transformers/models/vit/__init__.py b/src/transformers/models/vit/__init__.py index a54a26474b..a66edd8948 100644 --- a/src/transformers/models/vit/__init__.py +++ b/src/transformers/models/vit/__init__.py @@ -17,7 +17,7 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...file_utils import _LazyModule, is_flax_available, is_torch_available, is_vision_available +from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vision_available _import_structure = { @@ -35,6 +35,12 @@ if is_torch_available(): "ViTPreTrainedModel", ] +if is_tf_available(): + _import_structure["modeling_tf_vit"] = [ + "TFViTForImageClassification", + "TFViTModel", + "TFViTPreTrainedModel", + ] if is_flax_available(): _import_structure["modeling_flax_vit"] = [ @@ -57,6 +63,9 @@ if TYPE_CHECKING: ViTPreTrainedModel, ) + if is_tf_available(): + from .modeling_tf_vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel + if is_flax_available(): from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel diff --git a/src/transformers/models/vit/modeling_tf_vit.py b/src/transformers/models/vit/modeling_tf_vit.py new file mode 100644 index 0000000000..54030b2ee3 --- /dev/null +++ b/src/transformers/models/vit/modeling_tf_vit.py @@ -0,0 +1,859 @@ +# coding=utf-8 +# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" TF 2.0 ViT model. """ + + +import collections.abc +import math +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput +from ...modeling_tf_utils import ( + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + get_initializer, + input_processing, + keras_serializable, + shape_list, +) +from ...utils import logging +from .configuration_vit import ViTConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "ViTConfig" +_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224" + + +# Inspired by +# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py +# From PyTorch internals +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + + +class TFViTEmbeddings(tf.keras.layers.Layer): + """ + Construct the CLS token, position and patch embeddings. + + """ + + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.patch_embeddings = TFPatchEmbeddings(config, name="patch_embeddings") + self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def build(self, input_shape: tf.TensorShape): + + num_patches = self.patch_embeddings.num_patches + self.cls_token = self.add_weight( + shape=(1, 1, self.config.hidden_size), initializer="zeros", trainable=True, name="cls_token" + ) + self.position_embeddings = self.add_weight( + shape=(1, num_patches + 1, self.config.hidden_size), + initializer="zeros", + trainable=True, + name="position_embeddings", + ) + + super().build(input_shape) + + def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + batch_size, seq_len, dim = shape_list(embeddings) + npatch = seq_len - 1 + + _, N, _ = shape_list(self.position_embeddings) + N -= 1 + + if npatch == N and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + patch_pos_embed = tf.image.resize( + images=tf.reshape(patch_pos_embed, shape=(1, int(math.sqrt(N)), int(math.sqrt(N)), dim)), + size=(h0, w0), + method="bicubic", + ) + + shape = shape_list(patch_pos_embed) + assert h0 == shape[-3] and w0 == shape[-2] + patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim)) + return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1) + + def call( + self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False + ) -> tf.Tensor: + batch_size, num_channels, height, width = shape_list(pixel_values) + embeddings = self.patch_embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, training=training + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0) + embeddings = tf.concat((cls_tokens, embeddings), axis=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings, training=training) + + return embeddings + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +class TFPatchEmbeddings(tf.keras.layers.Layer): + """ + Image to Patch Embedding. + """ + + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + image_size = to_2tuple(config.image_size) + patch_size = to_2tuple(config.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = num_patches + self.num_channels = config.num_channels + self.embed_dim = config.hidden_size + self.config = config + + self.projection = tf.keras.layers.Conv2D( + filters=self.embed_dim, + kernel_size=patch_size, + strides=self.patch_size, + padding="valid", + data_format="channels_last", + use_bias=True, + kernel_initializer=get_initializer(self.config.initializer_range), + bias_initializer="zeros", + name="projection", + ) + + def call( + self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False + ) -> tf.Tensor: + batch_size, num_channels, height, width = shape_list(pixel_values) + if not interpolate_pos_encoding: + if getattr(height, "numpy", None) and getattr(width, "numpy", None): + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + + # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format. + # So change the input format from `NCHW` to `NHWC`. + # shape = (batch_size, in_height, in_width, in_channels=num_channels) + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + + projection = self.projection(pixel_values) + + # Change the 2D spatial dimensions to a single temporal dimension. + # shape = (batch_size, num_patches, out_channels=embed_dim) + num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0]) + x = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1)) + + return x + + +class TFViTSelfAttention(tf.keras.layers.Layer): + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number " + f"of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + + self.query = tf.keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = tf.keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = tf.keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + batch_size = shape_list(hidden_states)[0] + mixed_query_layer = self.query(inputs=hidden_states) + mixed_key_layer = self.key(inputs=hidden_states) + mixed_value_layer = self.value(inputs=hidden_states) + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) + value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) + + # Take the dot product between "query" and "key" to get the raw attention scores. + # (batch size, num_heads, seq_len_q, seq_len_k) + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) + attention_scores = tf.divide(attention_scores, dk) + + # Normalize the attention scores to probabilities. + attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(inputs=attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = tf.multiply(attention_probs, head_mask) + + attention_output = tf.matmul(attention_probs, value_layer) + attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, all_head_size) + attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + + return outputs + + +class TFViTSelfOutput(tf.keras.layers.Layer): + """ + The residual connection is defined in TFViTLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = tf.keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + + return hidden_states + + +class TFViTAttention(tf.keras.layers.Layer): + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.self_attention = TFViTSelfAttention(config, name="attention") + self.dense_output = TFViTSelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + input_tensor: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + self_outputs = self.self_attention( + hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training + ) + attention_output = self.dense_output( + hidden_states=self_outputs[0], input_tensor=input_tensor, training=training + ) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + + return outputs + + +class TFViTIntermediate(tf.keras.layers.Layer): + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = tf.keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class TFViTOutput(tf.keras.layers.Layer): + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = tf.keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = hidden_states + input_tensor + + return hidden_states + + +class TFViTLayer(tf.keras.layers.Layer): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.attention = TFViTAttention(config, name="attention") + self.intermediate = TFViTIntermediate(config, name="intermediate") + self.vit_output = TFViTOutput(config, name="output") + + self.layernorm_before = tf.keras.layers.LayerNormalization( + epsilon=config.layer_norm_eps, name="layernorm_before" + ) + self.layernorm_after = tf.keras.layers.LayerNormalization( + epsilon=config.layer_norm_eps, name="layernorm_after" + ) + + def call( + self, + hidden_states: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + attention_outputs = self.attention( + # in ViT, layernorm is applied before self-attention + input_tensor=self.layernorm_before(inputs=hidden_states), + head_mask=head_mask, + output_attentions=output_attentions, + training=training, + ) + attention_output = attention_outputs[0] + + # first residual connection + hidden_states = attention_output + hidden_states + + # in ViT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(inputs=hidden_states) + + intermediate_output = self.intermediate(hidden_states=layer_output) + + # second residual connection is done here + layer_output = self.vit_output( + hidden_states=intermediate_output, input_tensor=hidden_states, training=training + ) + outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + + return outputs + + +class TFViTEncoder(tf.keras.layers.Layer): + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.layer = [TFViTLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + + def call( + self, + hidden_states: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + output_hidden_states: bool, + return_dict: bool, + training: bool = False, + ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states=hidden_states, + head_mask=head_mask[i], + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +@keras_serializable +class TFViTMainLayer(tf.keras.layers.Layer): + config_class = ViTConfig + + def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, **kwargs): + super().__init__(**kwargs) + + self.config = config + + self.embeddings = TFViTEmbeddings(config, name="embeddings") + self.encoder = TFViTEncoder(config, name="encoder") + self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") + self.pooler = TFViTPooler(config, name="pooler") if add_pooling_layer else None + + def get_input_embeddings(self) -> tf.keras.layers.Layer: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + def call( + self, + pixel_values: Optional[TFModelInputType] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + **kwargs, + ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]: + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + training=training, + kwargs_call=kwargs, + ) + + if "input_ids" in inputs: + inputs["pixel_values"] = inputs.pop("input_ids") + + if inputs["pixel_values"] is None: + raise ValueError("You have to specify pixel_values") + + embedding_output = self.embeddings( + pixel_values=inputs["pixel_values"], + interpolate_pos_encoding=inputs["interpolate_pos_encoding"], + training=inputs["training"], + ) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if inputs["head_mask"] is not None: + raise NotImplementedError + else: + inputs["head_mask"] = [None] * self.config.num_hidden_layers + + encoder_outputs = self.encoder( + hidden_states=embedding_output, + head_mask=inputs["head_mask"], + output_attentions=inputs["output_attentions"], + output_hidden_states=inputs["output_hidden_states"], + return_dict=inputs["return_dict"], + training=inputs["training"], + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(inputs=sequence_output) + pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None + + if not inputs["return_dict"]: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return TFBaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class TFViTPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ViTConfig + base_model_prefix = "vit" + + @property + def dummy_inputs(self) -> Dict[str, tf.Tensor]: + """ + Dummy inputs to build the network. + + Returns: + :obj:`Dict[str, tf.Tensor]`: The dummy inputs. + """ + VISION_DUMMY_INPUTS = tf.random.uniform( + shape=(3, self.config.num_channels, self.config.image_size, self.config.image_size), dtype=tf.float32 + ) + return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)} + + @tf.function( + input_signature=[ + { + "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"), + } + ] + ) + def serving(self, inputs): + """ + Method used for serving the model. + + Args: + inputs (:obj:`Dict[str, tf.Tensor]`): + The input of the saved model as a dictionary of tensors. + """ + output = self.call(inputs) + + return self.serving_output(output) + + +VIT_START_DOCSTRING = r""" + + This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the + generic methods the library implements for all its model (such as downloading or saving, resizing the input + embeddings, pruning heads etc.) + + This model is also a `tf.keras.Model `__ subclass. Use + it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage + and behavior. + + .. note:: + + TF 2.0 models accepts two formats as inputs: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional arguments. + + This second option is useful when using :meth:`tf.keras.Model.fit` method which currently requires having all + the tensors in the first argument of the model call function: :obj:`model(inputs)`. + + Args: + config (:class:`~transformers.ViTConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the + model weights. +""" + +VIT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (:obj:`np.ndarray`, :obj:`tf.Tensor`, :obj:`List[tf.Tensor]` :obj:`Dict[str, tf.Tensor]` or :obj:`Dict[str, np.ndarray]` and each example must have the shape :obj:`(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using :class:`~transformers.ViTFeatureExtractor`. See + :meth:`transformers.ViTFeatureExtractor.__call__` for details. + + head_mask (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + interpolate_pos_encoding (:obj:`bool`, `optional`): + Whether to interpolate the pre-trained position encodings. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. This + argument can be used in eager mode, in graph mode the value will always be set to True. + training (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@add_start_docstrings( + "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.", + VIT_START_DOCSTRING, +) +class TFViTModel(TFViTPreTrainedModel): + def __init__(self, config: ViTConfig, *inputs, add_pooling_layer=True, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit") + + @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + def call( + self, + pixel_values: Optional[TFModelInputType] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + **kwargs, + ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]: + r""" + Returns: + + Examples:: + + >>> from transformers import ViTFeatureExtractor, TFViTModel + >>> from PIL import Image + >>> import requests + + >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k') + >>> model = TFViTModel.from_pretrained('google/vit-base-patch16-224-in21k') + + >>> inputs = feature_extractor(images=image, return_tensors="tf") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + """ + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + training=training, + kwargs_call=kwargs, + ) + + if "input_ids" in inputs: + inputs["pixel_values"] = inputs.pop("input_ids") + + outputs = self.vit( + pixel_values=inputs["pixel_values"], + head_mask=inputs["head_mask"], + output_attentions=inputs["output_attentions"], + output_hidden_states=inputs["output_hidden_states"], + interpolate_pos_encoding=inputs["interpolate_pos_encoding"], + return_dict=inputs["return_dict"], + training=inputs["training"], + ) + + return outputs + + def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling: + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFBaseModelOutputWithPooling( + last_hidden_state=output.last_hidden_state, + pooler_output=output.pooler_output, + hidden_states=hs, + attentions=attns, + ) + + +class TFViTPooler(tf.keras.layers.Layer): + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = tf.keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(inputs=first_token_tensor) + + return pooled_output + + +@add_start_docstrings( + """ + ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + """, + VIT_START_DOCSTRING, +) +class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: ViTConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + self.vit = TFViTMainLayer(config, add_pooling_layer=False, name="vit") + + # Classifier head + self.classifier = tf.keras.layers.Dense( + units=config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + name="classifier", + ) + + @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + pixel_values: Optional[TFModelInputType] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[Union[np.ndarray, tf.Tensor]] = None, + training: Optional[bool] = False, + **kwargs, + ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples:: + + >>> from transformers import ViTFeatureExtractor, TFViTForImageClassification + >>> from PIL import Image + >>> import requests + + >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224') + >>> model = TFViTForImageClassification.from_pretrained('google/vit-base-patch16-224') + + >>> inputs = feature_extractor(images=image, return_tensors="tf") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0] + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + """ + inputs = input_processing( + func=self.call, + config=self.config, + input_ids=pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + labels=labels, + training=training, + kwargs_call=kwargs, + ) + + if "input_ids" in inputs: + inputs["pixel_values"] = inputs.pop("input_ids") + + outputs = self.vit( + pixel_values=inputs["pixel_values"], + head_mask=inputs["head_mask"], + output_attentions=inputs["output_attentions"], + output_hidden_states=inputs["output_hidden_states"], + interpolate_pos_encoding=inputs["interpolate_pos_encoding"], + return_dict=inputs["return_dict"], + training=inputs["training"], + ) + sequence_output = outputs[0] + logits = self.classifier(inputs=sequence_output[:, 0, :]) + loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits) + + if not inputs["return_dict"]: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput: + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 01d3dc7db6..c3762125b5 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -176,6 +176,9 @@ class TFAlbertPreTrainedModel: TF_MODEL_FOR_CAUSAL_LM_MAPPING = None +TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None + + TF_MODEL_FOR_MASKED_LM_MAPPING = None @@ -224,6 +227,15 @@ class TFAutoModelForCausalLM: requires_backends(cls, ["tf"]) +class TFAutoModelForImageClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + class TFAutoModelForMaskedLM: def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) @@ -1971,6 +1983,29 @@ class TFTransfoXLPreTrainedModel: requires_backends(cls, ["tf"]) +class TFViTForImageClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + +class TFViTModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + +class TFViTPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 6da867e2a7..af002cea63 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1476,6 +1476,8 @@ class ModelTesterMixin: tf_inputs_dict[key] = tensor elif key == "input_values": tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32) + elif key == "pixel_values": + tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32) else: tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32) @@ -1525,6 +1527,8 @@ class ModelTesterMixin: tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32) elif key == "input_values": tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32) + elif key == "pixel_values": + tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32) else: tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 64ca24eeb6..2448ae9c90 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -49,6 +49,7 @@ if is_tf_available(): from transformers import ( TF_MODEL_FOR_CAUSAL_LM_MAPPING, + TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, @@ -126,7 +127,10 @@ class TFModelTesterMixin: elif model_class in get_values(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING): inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) - elif model_class in get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING): + elif model_class in [ + *get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING), + *get_values(TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING), + ]: inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) elif model_class in get_values(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING): inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) @@ -460,6 +464,8 @@ class TFModelTesterMixin: pt_inputs_dict[name] = key elif name == "input_values": pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) + elif name == "pixel_values": + pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) else: pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long) @@ -504,6 +510,8 @@ class TFModelTesterMixin: pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long) elif name == "input_values": pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) + elif name == "pixel_values": + pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32) else: pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long) # need to rename encoder-decoder "inputs" for PyTorch @@ -605,7 +613,7 @@ class TFModelTesterMixin: for model_class in self.all_model_classes: if self.is_encoder_decoder: - input_ids = { + inputs = { "decoder_input_ids": tf.keras.Input( batch_shape=(2, max_input), name="decoder_input_ids", @@ -613,10 +621,22 @@ class TFModelTesterMixin: ), "input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"), } + # TODO: A better way to handle vision models + elif model_class.__name__ in ["TFViTModel", "TFViTForImageClassification"]: + inputs = tf.keras.Input( + batch_shape=( + 3, + self.model_tester.num_channels, + self.model_tester.image_size, + self.model_tester.image_size, + ), + name="pixel_values", + dtype="float32", + ) elif model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING): - input_ids = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32") + inputs = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32") else: - input_ids = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32") + inputs = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32") # Prepare our model model = model_class(config) @@ -626,14 +646,14 @@ class TFModelTesterMixin: model.save_pretrained(tmpdirname, saved_model=False) model = model_class.from_pretrained(tmpdirname) - outputs_dict = model(input_ids) + outputs_dict = model(inputs) hidden_states = outputs_dict[0] # Add a dense layer on top to test integration with other keras modules outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states) # Compile extended model - extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs]) + extended_model = tf.keras.Model(inputs=[inputs], outputs=[outputs]) extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric]) def test_keyword_and_dict_args(self): @@ -647,6 +667,8 @@ class TFModelTesterMixin: inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) input_ids = inputs_keywords.pop("input_ids", None) + if input_ids is None: + input_ids = inputs_keywords.pop("pixel_values", None) outputs_keywords = model(input_ids, **inputs_keywords) output_dict = outputs_dict[0].numpy() output_keywords = outputs_keywords[0].numpy() @@ -1236,7 +1258,8 @@ class TFModelTesterMixin: # Test that model correctly compute the loss with kwargs prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) - input_ids = prepared_for_class.pop("input_ids") + input_name = "input_ids" if "input_ids" in prepared_for_class else "pixel_values" + input_ids = prepared_for_class.pop(input_name) loss = model(input_ids, **prepared_for_class)[0] self.assertEqual(loss.shape, [loss_size]) @@ -1255,7 +1278,7 @@ class TFModelTesterMixin: signature_names = list(signature.keys()) # Create a dictionary holding the location of the tensors in the tuple - tuple_index_mapping = {0: "input_ids"} + tuple_index_mapping = {0: input_name} for label_key in label_keys: label_key_index = signature_names.index(label_key) tuple_index_mapping[label_key_index] = label_key diff --git a/tests/test_modeling_tf_vit.py b/tests/test_modeling_tf_vit.py new file mode 100644 index 0000000000..1a83b2824a --- /dev/null +++ b/tests/test_modeling_tf_vit.py @@ -0,0 +1,389 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the TensorFlow ViT model. """ + + +import inspect +import os +import tempfile +import unittest + +from transformers import ViTConfig +from transformers.file_utils import cached_property, is_tf_available, is_vision_available +from transformers.testing_utils import require_tf, require_vision, slow, tooslow + +from .test_configuration_common import ConfigTester +from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor + + +if is_tf_available(): + import tensorflow as tf + + from transformers import TFViTForImageClassification, TFViTModel + from transformers.models.vit.modeling_tf_vit import to_2tuple + + +if is_vision_available(): + from PIL import Image + + from transformers import ViTFeatureExtractor + + +class TFViTModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=30, + patch_size=2, + num_channels=3, + is_training=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + type_sequence_label_size=10, + initializer_range=0.02, + num_labels=3, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.scope = scope + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return ViTConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + is_decoder=False, + initializer_range=self.initializer_range, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = TFViTModel(config=config) + result = model(pixel_values, training=False) + # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + image_size = to_2tuple(self.image_size) + patch_size = to_2tuple(self.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) + + # Test with an image with different size than the one specified in config. + image_size = self.image_size // 2 + pixel_values = pixel_values[:, :, :image_size, :image_size] + result = model(pixel_values, interpolate_pos_encoding=True, training=False) + # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + image_size = to_2tuple(image_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.type_sequence_label_size + model = TFViTForImageClassification(config) + result = model(pixel_values, labels=labels, training=False) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + + # Test with an image with different size than the one specified in config. + image_size = self.image_size // 2 + pixel_values = pixel_values[:, :, :image_size, :image_size] + result = model(pixel_values, interpolate_pos_encoding=True, training=False) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_tf +class TFViTModelTest(TFModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_tf_common.py, as ViT does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (TFViTModel, TFViTForImageClassification) if is_tf_available() else () + + test_resize_embeddings = False + test_head_masking = False + test_onnx = False + + def setUp(self): + self.model_tester = TFViTModelTester(self) + self.config_tester = ConfigTester(self, config_class=ViTConfig, has_text_modality=False, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_inputs_embeds(self): + # ViT does not use inputs_embeds + pass + + def test_graph_mode_with_inputs_embeds(self): + # ViT does not use inputs_embeds + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, tf.keras.layers.Layer)) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.call) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + # overwrite from common since `encoder_seq_length` and `encoder_key_length` are calculated + # in a different way than in text models. + @tooslow + def test_saved_model_creation_extended(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + if hasattr(config, "use_cache"): + config.use_cache = True + + # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token) + image_size = to_2tuple(self.model_tester.image_size) + patch_size = to_2tuple(self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_len = num_patches + 1 + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + + for model_class in self.all_model_classes: + class_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + num_out = len(model(class_inputs_dict)) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname, saved_model=True) + saved_model_dir = os.path.join(tmpdirname, "saved_model", "1") + model = tf.keras.models.load_model(saved_model_dir) + outputs = model(class_inputs_dict) + + if self.is_encoder_decoder: + output_hidden_states = outputs["encoder_hidden_states"] + output_attentions = outputs["encoder_attentions"] + else: + output_hidden_states = outputs["hidden_states"] + output_attentions = outputs["attentions"] + + self.assertEqual(len(outputs), num_out) + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + + self.assertEqual(len(output_hidden_states), expected_num_layers) + self.assertListEqual( + list(output_hidden_states[0].shape[-2:]), + [seq_len, self.model_tester.hidden_size], + ) + + self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(output_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token) + image_size = to_2tuple(self.model_tester.image_size) + patch_size = to_2tuple(self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_len = num_patches + 1 + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + # ViT has a different seq_length + image_size = to_2tuple(self.model_tester.image_size) + patch_size = to_2tuple(self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_length = num_patches + 1 + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + + model = TFViTModel.from_pretrained("google/vit-base-patch16-224", from_pt=True) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_vision +class TFViTModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224") if is_vision_available() else None + + @slow + def test_inference_image_classification_head(self): + model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224", from_pt=True) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="tf") + + # forward pass + outputs = model(**inputs) + + # verify the logits + expected_shape = tf.TensorShape((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = tf.constant([-0.2744, 0.8215, -0.0836]) + + tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4)