Add TFViTModel (#13778)
* Start the work for TFViTModel * Convert to TF code - need to check in the follow up commits * Clean up model code * Expose TFViTModel * make style * make quality * Add test * make style & quality * Fix some imports * fix wrong usage - *kwargs => ** kwargs * Fix Conv2D weight loading (PT->TF) issue * Add tests for images with different sizes + fix model * Fix some common tests for TFViTModel * Use inputs instead of input_ids in test_compile_tf_model * Add a comment about transpose and Conv2D in convert_tf_weight_name_to_pt_weight_name * Avoid transpose in TFViT call * Fix Conv2D issue in load_tf2_weights_in_pytorch_model * Use tf.keras.layers.Conv2D instead of tf.nn.conv2d * Using simpler heuristic to detect Conv2D layer * Change convert_tf_weight_name_to_pt_weight_name to return TransposeType * Check tf_weight_shape is not None before using it * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix missing comma * fix input dtype Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -503,7 +503,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| ViT | ❌ | ❌ | ✅ | ❌ | ✅ |
|
||||
| ViT | ❌ | ❌ | ✅ | ✅ | ✅ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
|
||||
@@ -223,6 +223,13 @@ TFAutoModelForCausalLM
|
||||
:members:
|
||||
|
||||
|
||||
TFAutoModelForImageClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFAutoModelForImageClassification
|
||||
:members:
|
||||
|
||||
|
||||
TFAutoModelForMaskedLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -120,6 +120,20 @@ ViTForImageClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
TFViTModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFViTModel
|
||||
:members: call
|
||||
|
||||
|
||||
TFViTForImageClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFViTForImageClassification
|
||||
:members: call
|
||||
|
||||
|
||||
FlaxVitModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -1396,6 +1396,7 @@ if is_tf_available():
|
||||
_import_structure["models.auto"].extend(
|
||||
[
|
||||
"TF_MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||
"TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_FOR_MASKED_LM_MAPPING",
|
||||
"TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||
"TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
@@ -1408,6 +1409,7 @@ if is_tf_available():
|
||||
"TF_MODEL_WITH_LM_HEAD_MAPPING",
|
||||
"TFAutoModel",
|
||||
"TFAutoModelForCausalLM",
|
||||
"TFAutoModelForImageClassification",
|
||||
"TFAutoModelForMaskedLM",
|
||||
"TFAutoModelForMultipleChoice",
|
||||
"TFAutoModelForPreTraining",
|
||||
@@ -1734,6 +1736,13 @@ if is_tf_available():
|
||||
"TFTransfoXLPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.vit"].extend(
|
||||
[
|
||||
"TFViTForImageClassification",
|
||||
"TFViTModel",
|
||||
"TFViTPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.wav2vec2"].extend(
|
||||
[
|
||||
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@@ -3133,6 +3142,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .models.auto import (
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
@@ -3145,6 +3155,7 @@ if TYPE_CHECKING:
|
||||
TF_MODEL_WITH_LM_HEAD_MAPPING,
|
||||
TFAutoModel,
|
||||
TFAutoModelForCausalLM,
|
||||
TFAutoModelForImageClassification,
|
||||
TFAutoModelForMaskedLM,
|
||||
TFAutoModelForMultipleChoice,
|
||||
TFAutoModelForPreTraining,
|
||||
@@ -3406,6 +3417,7 @@ if TYPE_CHECKING:
|
||||
TFTransfoXLModel,
|
||||
TFTransfoXLPreTrainedModel,
|
||||
)
|
||||
from .models.vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
|
||||
from .models.wav2vec2 import (
|
||||
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFWav2Vec2ForCTC,
|
||||
|
||||
@@ -21,13 +21,24 @@ import re
|
||||
|
||||
import numpy
|
||||
|
||||
from .file_utils import ExplicitEnum
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=""):
|
||||
class TransposeType(ExplicitEnum):
|
||||
"""
|
||||
Possible ...
|
||||
"""
|
||||
|
||||
NO = "no"
|
||||
SIMPLE = "simple"
|
||||
CONV2D = "conv2d"
|
||||
|
||||
|
||||
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="", tf_weight_shape=None):
|
||||
"""
|
||||
Convert a TF 2.0 model variable name in a pytorch model weight name.
|
||||
|
||||
@@ -39,8 +50,8 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
|
||||
return tuple with:
|
||||
|
||||
- pytorch model weight name
|
||||
- transpose: boolean indicating whether TF2.0 and PyTorch weights matrices are transposed with regards to each
|
||||
other
|
||||
- transpose: `TransposeType` member indicating whether and how TF2.0 and PyTorch weights matrices should be
|
||||
transposed with regards to each other
|
||||
"""
|
||||
tf_name = tf_name.replace(":0", "") # device ids
|
||||
tf_name = re.sub(
|
||||
@@ -56,11 +67,17 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
|
||||
tf_name = tf_name[1:] # Remove level zero
|
||||
|
||||
# When should we transpose the weights
|
||||
transpose = bool(
|
||||
if tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 4:
|
||||
# A simple heuristic to detect conv layer using weight array shape
|
||||
transpose = TransposeType.CONV2D
|
||||
elif bool(
|
||||
tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
|
||||
or "emb_projs" in tf_name
|
||||
or "out_projs" in tf_name
|
||||
)
|
||||
):
|
||||
transpose = TransposeType.SIMPLE
|
||||
else:
|
||||
transpose = TransposeType.NO
|
||||
|
||||
# Convert standard TF2.0 names in PyTorch names
|
||||
if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma":
|
||||
@@ -165,7 +182,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
||||
for symbolic_weight in symbolic_weights:
|
||||
sw_name = symbolic_weight.name
|
||||
name, transpose = convert_tf_weight_name_to_pt_weight_name(
|
||||
sw_name, start_prefix_to_remove=start_prefix_to_remove
|
||||
sw_name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=symbolic_weight.shape
|
||||
)
|
||||
|
||||
# Find associated numpy array in pytorch model state dict
|
||||
@@ -182,7 +199,12 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
||||
|
||||
array = pt_state_dict[name].numpy()
|
||||
|
||||
if transpose:
|
||||
if transpose is TransposeType.CONV2D:
|
||||
# Conv2D weight:
|
||||
# PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
|
||||
# -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
|
||||
array = numpy.transpose(array, axes=(2, 3, 1, 0))
|
||||
elif transpose is TransposeType.SIMPLE:
|
||||
array = numpy.transpose(array)
|
||||
|
||||
if len(symbolic_weight.shape) < len(array.shape):
|
||||
@@ -326,7 +348,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
|
||||
tf_weights_map = {}
|
||||
for tf_weight in tf_weights:
|
||||
pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
|
||||
tf_weight.name, start_prefix_to_remove=start_prefix_to_remove
|
||||
tf_weight.name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
|
||||
)
|
||||
tf_weights_map[pt_name] = (tf_weight.numpy(), transpose)
|
||||
|
||||
@@ -350,7 +372,12 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
|
||||
|
||||
array, transpose = tf_weights_map[pt_weight_name]
|
||||
|
||||
if transpose:
|
||||
if transpose is TransposeType.CONV2D:
|
||||
# Conv2D weight:
|
||||
# TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
|
||||
# -> PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
|
||||
array = numpy.transpose(array, axes=(3, 2, 0, 1))
|
||||
elif transpose is TransposeType.SIMPLE:
|
||||
array = numpy.transpose(array)
|
||||
|
||||
if len(pt_weight.shape) < len(array.shape):
|
||||
|
||||
@@ -73,6 +73,7 @@ if is_torch_available():
|
||||
if is_tf_available():
|
||||
_import_structure["modeling_tf_auto"] = [
|
||||
"TF_MODEL_FOR_CAUSAL_LM_MAPPING",
|
||||
"TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_FOR_MASKED_LM_MAPPING",
|
||||
"TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
|
||||
"TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
|
||||
@@ -85,6 +86,7 @@ if is_tf_available():
|
||||
"TF_MODEL_WITH_LM_HEAD_MAPPING",
|
||||
"TFAutoModel",
|
||||
"TFAutoModelForCausalLM",
|
||||
"TFAutoModelForImageClassification",
|
||||
"TFAutoModelForMaskedLM",
|
||||
"TFAutoModelForMultipleChoice",
|
||||
"TFAutoModelForPreTraining",
|
||||
@@ -175,6 +177,7 @@ if TYPE_CHECKING:
|
||||
if is_tf_available():
|
||||
from .modeling_tf_auto import (
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
@@ -187,6 +190,7 @@ if TYPE_CHECKING:
|
||||
TF_MODEL_WITH_LM_HEAD_MAPPING,
|
||||
TFAutoModel,
|
||||
TFAutoModelForCausalLM,
|
||||
TFAutoModelForImageClassification,
|
||||
TFAutoModelForMaskedLM,
|
||||
TFAutoModelForMultipleChoice,
|
||||
TFAutoModelForPreTraining,
|
||||
|
||||
@@ -64,6 +64,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("pegasus", "TFPegasusModel"),
|
||||
("blenderbot", "TFBlenderbotModel"),
|
||||
("blenderbot-small", "TFBlenderbotSmallModel"),
|
||||
("vit", "TFViTModel"),
|
||||
("wav2vec2", "TFWav2Vec2Model"),
|
||||
("hubert", "TFHubertModel"),
|
||||
]
|
||||
@@ -144,6 +145,13 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Image-classsification
|
||||
("vit", "TFViTForImageClassification"),
|
||||
]
|
||||
)
|
||||
|
||||
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Masked LM mapping
|
||||
@@ -302,6 +310,9 @@ TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES
|
||||
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
||||
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
|
||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
||||
)
|
||||
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
|
||||
@@ -352,6 +363,13 @@ class TFAutoModelForCausalLM(_BaseAutoModelClass):
|
||||
TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
|
||||
|
||||
|
||||
class TFAutoModelForImageClassification(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
|
||||
|
||||
|
||||
AutoModelForImageClassification = auto_class_update(TFAutoModelForImageClassification, head_doc="image classification")
|
||||
|
||||
|
||||
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _LazyModule, is_flax_available, is_torch_available, is_vision_available
|
||||
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vision_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
@@ -35,6 +35,12 @@ if is_torch_available():
|
||||
"ViTPreTrainedModel",
|
||||
]
|
||||
|
||||
if is_tf_available():
|
||||
_import_structure["modeling_tf_vit"] = [
|
||||
"TFViTForImageClassification",
|
||||
"TFViTModel",
|
||||
"TFViTPreTrainedModel",
|
||||
]
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["modeling_flax_vit"] = [
|
||||
@@ -57,6 +63,9 @@ if TYPE_CHECKING:
|
||||
ViTPreTrainedModel,
|
||||
)
|
||||
|
||||
if is_tf_available():
|
||||
from .modeling_tf_vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
|
||||
|
||||
if is_flax_available():
|
||||
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
|
||||
|
||||
|
||||
859
src/transformers/models/vit/modeling_tf_vit.py
Normal file
859
src/transformers/models/vit/modeling_tf_vit.py
Normal file
@@ -0,0 +1,859 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" TF 2.0 ViT model. """
|
||||
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
|
||||
from ...modeling_tf_utils import (
|
||||
TFModelInputType,
|
||||
TFPreTrainedModel,
|
||||
TFSequenceClassificationLoss,
|
||||
get_initializer,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
)
|
||||
from ...utils import logging
|
||||
from .configuration_vit import ViTConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "ViTConfig"
|
||||
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224"
|
||||
|
||||
|
||||
# Inspired by
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
|
||||
# From PyTorch internals
|
||||
def to_2tuple(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return (x, x)
|
||||
|
||||
|
||||
# Based on timm implementation, which can be found here:
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
|
||||
|
||||
class TFViTEmbeddings(tf.keras.layers.Layer):
|
||||
"""
|
||||
Construct the CLS token, position and patch embeddings.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: ViTConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.patch_embeddings = TFPatchEmbeddings(config, name="patch_embeddings")
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
||||
self.config = config
|
||||
|
||||
def build(self, input_shape: tf.TensorShape):
|
||||
|
||||
num_patches = self.patch_embeddings.num_patches
|
||||
self.cls_token = self.add_weight(
|
||||
shape=(1, 1, self.config.hidden_size), initializer="zeros", trainable=True, name="cls_token"
|
||||
)
|
||||
self.position_embeddings = self.add_weight(
|
||||
shape=(1, num_patches + 1, self.config.hidden_size),
|
||||
initializer="zeros",
|
||||
trainable=True,
|
||||
name="position_embeddings",
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
||||
resolution images.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
"""
|
||||
|
||||
batch_size, seq_len, dim = shape_list(embeddings)
|
||||
npatch = seq_len - 1
|
||||
|
||||
_, N, _ = shape_list(self.position_embeddings)
|
||||
N -= 1
|
||||
|
||||
if npatch == N and height == width:
|
||||
return self.position_embeddings
|
||||
class_pos_embed = self.position_embeddings[:, :1]
|
||||
patch_pos_embed = self.position_embeddings[:, 1:]
|
||||
h0 = height // self.config.patch_size
|
||||
w0 = width // self.config.patch_size
|
||||
patch_pos_embed = tf.image.resize(
|
||||
images=tf.reshape(patch_pos_embed, shape=(1, int(math.sqrt(N)), int(math.sqrt(N)), dim)),
|
||||
size=(h0, w0),
|
||||
method="bicubic",
|
||||
)
|
||||
|
||||
shape = shape_list(patch_pos_embed)
|
||||
assert h0 == shape[-3] and w0 == shape[-2]
|
||||
patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
|
||||
return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
|
||||
|
||||
def call(
|
||||
self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
|
||||
) -> tf.Tensor:
|
||||
batch_size, num_channels, height, width = shape_list(pixel_values)
|
||||
embeddings = self.patch_embeddings(
|
||||
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, training=training
|
||||
)
|
||||
|
||||
# add the [CLS] token to the embedded patch tokens
|
||||
cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
|
||||
embeddings = tf.concat((cls_tokens, embeddings), axis=1)
|
||||
|
||||
# add positional encoding to each token
|
||||
if interpolate_pos_encoding:
|
||||
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||
else:
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
|
||||
embeddings = self.dropout(embeddings, training=training)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
# Based on timm implementation, which can be found here:
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
class TFPatchEmbeddings(tf.keras.layers.Layer):
|
||||
"""
|
||||
Image to Patch Embedding.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ViTConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
image_size = to_2tuple(config.image_size)
|
||||
patch_size = to_2tuple(config.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_patches = num_patches
|
||||
self.num_channels = config.num_channels
|
||||
self.embed_dim = config.hidden_size
|
||||
self.config = config
|
||||
|
||||
self.projection = tf.keras.layers.Conv2D(
|
||||
filters=self.embed_dim,
|
||||
kernel_size=patch_size,
|
||||
strides=self.patch_size,
|
||||
padding="valid",
|
||||
data_format="channels_last",
|
||||
use_bias=True,
|
||||
kernel_initializer=get_initializer(self.config.initializer_range),
|
||||
bias_initializer="zeros",
|
||||
name="projection",
|
||||
)
|
||||
|
||||
def call(
|
||||
self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
|
||||
) -> tf.Tensor:
|
||||
batch_size, num_channels, height, width = shape_list(pixel_values)
|
||||
if not interpolate_pos_encoding:
|
||||
if getattr(height, "numpy", None) and getattr(width, "numpy", None):
|
||||
if height != self.image_size[0] or width != self.image_size[1]:
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
||||
)
|
||||
|
||||
# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
|
||||
# So change the input format from `NCHW` to `NHWC`.
|
||||
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
|
||||
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
|
||||
|
||||
projection = self.projection(pixel_values)
|
||||
|
||||
# Change the 2D spatial dimensions to a single temporal dimension.
|
||||
# shape = (batch_size, num_patches, out_channels=embed_dim)
|
||||
num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
|
||||
x = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TFViTSelfAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, config: ViTConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
|
||||
f"of attention heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
|
||||
|
||||
self.query = tf.keras.layers.Dense(
|
||||
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
||||
)
|
||||
self.key = tf.keras.layers.Dense(
|
||||
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
|
||||
)
|
||||
self.value = tf.keras.layers.Dense(
|
||||
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
|
||||
)
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
||||
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
||||
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||
|
||||
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
|
||||
return tf.transpose(tensor, perm=[0, 2, 1, 3])
|
||||
|
||||
def call(
|
||||
self,
|
||||
hidden_states: tf.Tensor,
|
||||
head_mask: tf.Tensor,
|
||||
output_attentions: bool,
|
||||
training: bool = False,
|
||||
) -> Tuple[tf.Tensor]:
|
||||
batch_size = shape_list(hidden_states)[0]
|
||||
mixed_query_layer = self.query(inputs=hidden_states)
|
||||
mixed_key_layer = self.key(inputs=hidden_states)
|
||||
mixed_value_layer = self.value(inputs=hidden_states)
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
# (batch size, num_heads, seq_len_q, seq_len_k)
|
||||
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
|
||||
dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
|
||||
attention_scores = tf.divide(attention_scores, dk)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(inputs=attention_probs, training=training)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = tf.multiply(attention_probs, head_mask)
|
||||
|
||||
attention_output = tf.matmul(attention_probs, value_layer)
|
||||
attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
|
||||
|
||||
# (batch_size, seq_len_q, all_head_size)
|
||||
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
||||
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class TFViTSelfOutput(tf.keras.layers.Layer):
|
||||
"""
|
||||
The residual connection is defined in TFViTLayer instead of here (as is the case with other models), due to the
|
||||
layernorm applied before each block.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ViTConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
||||
|
||||
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
|
||||
hidden_states = self.dense(inputs=hidden_states)
|
||||
hidden_states = self.dropout(inputs=hidden_states, training=training)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TFViTAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, config: ViTConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.self_attention = TFViTSelfAttention(config, name="attention")
|
||||
self.dense_output = TFViTSelfOutput(config, name="output")
|
||||
|
||||
def prune_heads(self, heads):
|
||||
raise NotImplementedError
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_tensor: tf.Tensor,
|
||||
head_mask: tf.Tensor,
|
||||
output_attentions: bool,
|
||||
training: bool = False,
|
||||
) -> Tuple[tf.Tensor]:
|
||||
self_outputs = self.self_attention(
|
||||
hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training
|
||||
)
|
||||
attention_output = self.dense_output(
|
||||
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
||||
)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class TFViTIntermediate(tf.keras.layers.Layer):
|
||||
def __init__(self, config: ViTConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
||||
hidden_states = self.dense(inputs=hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TFViTOutput(tf.keras.layers.Layer):
|
||||
def __init__(self, config: ViTConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
||||
|
||||
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
|
||||
hidden_states = self.dense(inputs=hidden_states)
|
||||
hidden_states = self.dropout(inputs=hidden_states, training=training)
|
||||
hidden_states = hidden_states + input_tensor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TFViTLayer(tf.keras.layers.Layer):
|
||||
"""This corresponds to the Block class in the timm implementation."""
|
||||
|
||||
def __init__(self, config: ViTConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.attention = TFViTAttention(config, name="attention")
|
||||
self.intermediate = TFViTIntermediate(config, name="intermediate")
|
||||
self.vit_output = TFViTOutput(config, name="output")
|
||||
|
||||
self.layernorm_before = tf.keras.layers.LayerNormalization(
|
||||
epsilon=config.layer_norm_eps, name="layernorm_before"
|
||||
)
|
||||
self.layernorm_after = tf.keras.layers.LayerNormalization(
|
||||
epsilon=config.layer_norm_eps, name="layernorm_after"
|
||||
)
|
||||
|
||||
def call(
|
||||
self,
|
||||
hidden_states: tf.Tensor,
|
||||
head_mask: tf.Tensor,
|
||||
output_attentions: bool,
|
||||
training: bool = False,
|
||||
) -> Tuple[tf.Tensor]:
|
||||
attention_outputs = self.attention(
|
||||
# in ViT, layernorm is applied before self-attention
|
||||
input_tensor=self.layernorm_before(inputs=hidden_states),
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
training=training,
|
||||
)
|
||||
attention_output = attention_outputs[0]
|
||||
|
||||
# first residual connection
|
||||
hidden_states = attention_output + hidden_states
|
||||
|
||||
# in ViT, layernorm is also applied after self-attention
|
||||
layer_output = self.layernorm_after(inputs=hidden_states)
|
||||
|
||||
intermediate_output = self.intermediate(hidden_states=layer_output)
|
||||
|
||||
# second residual connection is done here
|
||||
layer_output = self.vit_output(
|
||||
hidden_states=intermediate_output, input_tensor=hidden_states, training=training
|
||||
)
|
||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class TFViTEncoder(tf.keras.layers.Layer):
|
||||
def __init__(self, config: ViTConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.layer = [TFViTLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
||||
|
||||
def call(
|
||||
self,
|
||||
hidden_states: tf.Tensor,
|
||||
head_mask: tf.Tensor,
|
||||
output_attentions: bool,
|
||||
output_hidden_states: bool,
|
||||
return_dict: bool,
|
||||
training: bool = False,
|
||||
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states=hidden_states,
|
||||
head_mask=head_mask[i],
|
||||
output_attentions=output_attentions,
|
||||
training=training,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
# Add last layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
||||
|
||||
return TFBaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
||||
)
|
||||
|
||||
|
||||
@keras_serializable
|
||||
class TFViTMainLayer(tf.keras.layers.Layer):
|
||||
config_class = ViTConfig
|
||||
|
||||
def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
|
||||
self.embeddings = TFViTEmbeddings(config, name="embeddings")
|
||||
self.encoder = TFViTEncoder(config, name="encoder")
|
||||
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
|
||||
self.pooler = TFViTPooler(config, name="pooler") if add_pooling_layer else None
|
||||
|
||||
def get_input_embeddings(self) -> tf.keras.layers.Layer:
|
||||
return self.embeddings.patch_embeddings
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||
class PreTrainedModel
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def call(
|
||||
self,
|
||||
pixel_values: Optional[TFModelInputType] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
training: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=pixel_values,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
if "input_ids" in inputs:
|
||||
inputs["pixel_values"] = inputs.pop("input_ids")
|
||||
|
||||
if inputs["pixel_values"] is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
pixel_values=inputs["pixel_values"],
|
||||
interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
inputs["head_mask"] = [None] * self.config.num_hidden_layers
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
hidden_states=embedding_output,
|
||||
head_mask=inputs["head_mask"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
sequence_output = encoder_outputs[0]
|
||||
sequence_output = self.layernorm(inputs=sequence_output)
|
||||
pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return TFBaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
class TFViTPreTrainedModel(TFPreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = ViTConfig
|
||||
base_model_prefix = "vit"
|
||||
|
||||
@property
|
||||
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
|
||||
"""
|
||||
Dummy inputs to build the network.
|
||||
|
||||
Returns:
|
||||
:obj:`Dict[str, tf.Tensor]`: The dummy inputs.
|
||||
"""
|
||||
VISION_DUMMY_INPUTS = tf.random.uniform(
|
||||
shape=(3, self.config.num_channels, self.config.image_size, self.config.image_size), dtype=tf.float32
|
||||
)
|
||||
return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
|
||||
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
{
|
||||
"pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
|
||||
}
|
||||
]
|
||||
)
|
||||
def serving(self, inputs):
|
||||
"""
|
||||
Method used for serving the model.
|
||||
|
||||
Args:
|
||||
inputs (:obj:`Dict[str, tf.Tensor]`):
|
||||
The input of the saved model as a dictionary of tensors.
|
||||
"""
|
||||
output = self.call(inputs)
|
||||
|
||||
return self.serving_output(output)
|
||||
|
||||
|
||||
VIT_START_DOCSTRING = r"""
|
||||
|
||||
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
|
||||
generic methods the library implements for all its model (such as downloading or saving, resizing the input
|
||||
embeddings, pruning heads etc.)
|
||||
|
||||
This model is also a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ subclass. Use
|
||||
it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
.. note::
|
||||
|
||||
TF 2.0 models accepts two formats as inputs:
|
||||
|
||||
- having all inputs as keyword arguments (like PyTorch models), or
|
||||
- having all inputs as a list, tuple or dict in the first positional arguments.
|
||||
|
||||
This second option is useful when using :meth:`tf.keras.Model.fit` method which currently requires having all
|
||||
the tensors in the first argument of the model call function: :obj:`model(inputs)`.
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.ViTConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the
|
||||
model weights.
|
||||
"""
|
||||
|
||||
VIT_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (:obj:`np.ndarray`, :obj:`tf.Tensor`, :obj:`List[tf.Tensor]` :obj:`Dict[str, tf.Tensor]` or :obj:`Dict[str, np.ndarray]` and each example must have the shape :obj:`(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Pixel values can be obtained using :class:`~transformers.ViTFeatureExtractor`. See
|
||||
:meth:`transformers.ViTFeatureExtractor.__call__` for details.
|
||||
|
||||
head_mask (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
|
||||
config will be used instead.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
|
||||
used instead.
|
||||
interpolate_pos_encoding (:obj:`bool`, `optional`):
|
||||
Whether to interpolate the pre-trained position encodings.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. This
|
||||
argument can be used in eager mode, in graph mode the value will always be set to True.
|
||||
training (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to use the model in training mode (some modules like dropout modules have different
|
||||
behaviors between training and evaluation).
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
VIT_START_DOCSTRING,
|
||||
)
|
||||
class TFViTModel(TFViTPreTrainedModel):
|
||||
def __init__(self, config: ViTConfig, *inputs, add_pooling_layer=True, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit")
|
||||
|
||||
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
pixel_values: Optional[TFModelInputType] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
training: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import ViTFeatureExtractor, TFViTModel
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
|
||||
>>> model = TFViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
|
||||
|
||||
>>> inputs = feature_extractor(images=image, return_tensors="tf")
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_states = outputs.last_hidden_state
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=pixel_values,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
if "input_ids" in inputs:
|
||||
inputs["pixel_values"] = inputs.pop("input_ids")
|
||||
|
||||
outputs = self.vit(
|
||||
pixel_values=inputs["pixel_values"],
|
||||
head_mask=inputs["head_mask"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFBaseModelOutputWithPooling(
|
||||
last_hidden_state=output.last_hidden_state,
|
||||
pooler_output=output.pooler_output,
|
||||
hidden_states=hs,
|
||||
attentions=attns,
|
||||
)
|
||||
|
||||
|
||||
class TFViTPooler(tf.keras.layers.Layer):
|
||||
def __init__(self, config: ViTConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
units=config.hidden_size,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
activation="tanh",
|
||||
name="dense",
|
||||
)
|
||||
|
||||
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
first_token_tensor = hidden_states[:, 0]
|
||||
pooled_output = self.dense(inputs=first_token_tensor)
|
||||
|
||||
return pooled_output
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
|
||||
the [CLS] token) e.g. for ImageNet.
|
||||
""",
|
||||
VIT_START_DOCSTRING,
|
||||
)
|
||||
class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):
|
||||
def __init__(self, config: ViTConfig, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.vit = TFViTMainLayer(config, add_pooling_layer=False, name="vit")
|
||||
|
||||
# Classifier head
|
||||
self.classifier = tf.keras.layers.Dense(
|
||||
units=config.num_labels,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="classifier",
|
||||
)
|
||||
|
||||
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def call(
|
||||
self,
|
||||
pixel_values: Optional[TFModelInputType] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
training: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
|
||||
r"""
|
||||
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ...,
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import ViTFeatureExtractor, TFViTForImageClassification
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
|
||||
>>> model = TFViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
|
||||
|
||||
>>> inputs = feature_extractor(images=image, return_tensors="tf")
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits = outputs.logits
|
||||
>>> # model predicts one of the 1000 ImageNet classes
|
||||
>>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
|
||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
config=self.config,
|
||||
input_ids=pixel_values,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
training=training,
|
||||
kwargs_call=kwargs,
|
||||
)
|
||||
|
||||
if "input_ids" in inputs:
|
||||
inputs["pixel_values"] = inputs.pop("input_ids")
|
||||
|
||||
outputs = self.vit(
|
||||
pixel_values=inputs["pixel_values"],
|
||||
head_mask=inputs["head_mask"],
|
||||
output_attentions=inputs["output_attentions"],
|
||||
output_hidden_states=inputs["output_hidden_states"],
|
||||
interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
|
||||
return_dict=inputs["return_dict"],
|
||||
training=inputs["training"],
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
logits = self.classifier(inputs=sequence_output[:, 0, :])
|
||||
loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
|
||||
|
||||
if not inputs["return_dict"]:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TFSequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
|
||||
@@ -176,6 +176,9 @@ class TFAlbertPreTrainedModel:
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING = None
|
||||
|
||||
|
||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
TF_MODEL_FOR_MASKED_LM_MAPPING = None
|
||||
|
||||
|
||||
@@ -224,6 +227,15 @@ class TFAutoModelForCausalLM:
|
||||
requires_backends(cls, ["tf"])
|
||||
|
||||
|
||||
class TFAutoModelForImageClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["tf"])
|
||||
|
||||
|
||||
class TFAutoModelForMaskedLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
@@ -1971,6 +1983,29 @@ class TFTransfoXLPreTrainedModel:
|
||||
requires_backends(cls, ["tf"])
|
||||
|
||||
|
||||
class TFViTForImageClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFViTModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["tf"])
|
||||
|
||||
|
||||
class TFViTPreTrainedModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["tf"])
|
||||
|
||||
|
||||
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
||||
@@ -1476,6 +1476,8 @@ class ModelTesterMixin:
|
||||
tf_inputs_dict[key] = tensor
|
||||
elif key == "input_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||
elif key == "pixel_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||
else:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
|
||||
|
||||
@@ -1525,6 +1527,8 @@ class ModelTesterMixin:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
|
||||
elif key == "input_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||
elif key == "pixel_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||
else:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ if is_tf_available():
|
||||
|
||||
from transformers import (
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
@@ -126,7 +127,10 @@ class TFModelTesterMixin:
|
||||
elif model_class in get_values(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
|
||||
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
|
||||
elif model_class in [
|
||||
*get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
|
||||
*get_values(TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
|
||||
]:
|
||||
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in get_values(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING):
|
||||
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
@@ -460,6 +464,8 @@ class TFModelTesterMixin:
|
||||
pt_inputs_dict[name] = key
|
||||
elif name == "input_values":
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
||||
elif name == "pixel_values":
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
||||
else:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||
|
||||
@@ -504,6 +510,8 @@ class TFModelTesterMixin:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
|
||||
elif name == "input_values":
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
||||
elif name == "pixel_values":
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
||||
else:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||
# need to rename encoder-decoder "inputs" for PyTorch
|
||||
@@ -605,7 +613,7 @@ class TFModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if self.is_encoder_decoder:
|
||||
input_ids = {
|
||||
inputs = {
|
||||
"decoder_input_ids": tf.keras.Input(
|
||||
batch_shape=(2, max_input),
|
||||
name="decoder_input_ids",
|
||||
@@ -613,10 +621,22 @@ class TFModelTesterMixin:
|
||||
),
|
||||
"input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
|
||||
}
|
||||
# TODO: A better way to handle vision models
|
||||
elif model_class.__name__ in ["TFViTModel", "TFViTForImageClassification"]:
|
||||
inputs = tf.keras.Input(
|
||||
batch_shape=(
|
||||
3,
|
||||
self.model_tester.num_channels,
|
||||
self.model_tester.image_size,
|
||||
self.model_tester.image_size,
|
||||
),
|
||||
name="pixel_values",
|
||||
dtype="float32",
|
||||
)
|
||||
elif model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
input_ids = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
|
||||
inputs = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
|
||||
else:
|
||||
input_ids = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")
|
||||
inputs = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")
|
||||
|
||||
# Prepare our model
|
||||
model = model_class(config)
|
||||
@@ -626,14 +646,14 @@ class TFModelTesterMixin:
|
||||
model.save_pretrained(tmpdirname, saved_model=False)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_dict = model(input_ids)
|
||||
outputs_dict = model(inputs)
|
||||
hidden_states = outputs_dict[0]
|
||||
|
||||
# Add a dense layer on top to test integration with other keras modules
|
||||
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
|
||||
|
||||
# Compile extended model
|
||||
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs])
|
||||
extended_model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
|
||||
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
|
||||
|
||||
def test_keyword_and_dict_args(self):
|
||||
@@ -647,6 +667,8 @@ class TFModelTesterMixin:
|
||||
|
||||
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
input_ids = inputs_keywords.pop("input_ids", None)
|
||||
if input_ids is None:
|
||||
input_ids = inputs_keywords.pop("pixel_values", None)
|
||||
outputs_keywords = model(input_ids, **inputs_keywords)
|
||||
output_dict = outputs_dict[0].numpy()
|
||||
output_keywords = outputs_keywords[0].numpy()
|
||||
@@ -1236,7 +1258,8 @@ class TFModelTesterMixin:
|
||||
|
||||
# Test that model correctly compute the loss with kwargs
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
input_ids = prepared_for_class.pop("input_ids")
|
||||
input_name = "input_ids" if "input_ids" in prepared_for_class else "pixel_values"
|
||||
input_ids = prepared_for_class.pop(input_name)
|
||||
|
||||
loss = model(input_ids, **prepared_for_class)[0]
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
@@ -1255,7 +1278,7 @@ class TFModelTesterMixin:
|
||||
signature_names = list(signature.keys())
|
||||
|
||||
# Create a dictionary holding the location of the tensors in the tuple
|
||||
tuple_index_mapping = {0: "input_ids"}
|
||||
tuple_index_mapping = {0: input_name}
|
||||
for label_key in label_keys:
|
||||
label_key_index = signature_names.index(label_key)
|
||||
tuple_index_mapping[label_key_index] = label_key
|
||||
|
||||
389
tests/test_modeling_tf_vit.py
Normal file
389
tests/test_modeling_tf_vit.py
Normal file
@@ -0,0 +1,389 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the TensorFlow ViT model. """
|
||||
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import ViTConfig
|
||||
from transformers.file_utils import cached_property, is_tf_available, is_vision_available
|
||||
from transformers.testing_utils import require_tf, require_vision, slow, tooslow
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFViTForImageClassification, TFViTModel
|
||||
from transformers.models.vit.modeling_tf_vit import to_2tuple
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import ViTFeatureExtractor
|
||||
|
||||
|
||||
class TFViTModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
image_size=30,
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
type_sequence_label_size=10,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
labels = None
|
||||
if self.use_labels:
|
||||
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values, labels
|
||||
|
||||
def get_config(self):
|
||||
return ViTConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
model = TFViTModel(config=config)
|
||||
result = model(pixel_values, training=False)
|
||||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
||||
image_size = to_2tuple(self.image_size)
|
||||
patch_size = to_2tuple(self.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||
|
||||
# Test with an image with different size than the one specified in config.
|
||||
image_size = self.image_size // 2
|
||||
pixel_values = pixel_values[:, :, :image_size, :image_size]
|
||||
result = model(pixel_values, interpolate_pos_encoding=True, training=False)
|
||||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
||||
image_size = to_2tuple(image_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||
|
||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||
config.num_labels = self.type_sequence_label_size
|
||||
model = TFViTForImageClassification(config)
|
||||
result = model(pixel_values, labels=labels, training=False)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
|
||||
# Test with an image with different size than the one specified in config.
|
||||
image_size = self.image_size // 2
|
||||
pixel_values = pixel_values[:, :, :image_size, :image_size]
|
||||
result = model(pixel_values, interpolate_pos_encoding=True, training=False)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values, labels = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Here we also overwrite some of the tests of test_modeling_tf_common.py, as ViT does not use input_ids, inputs_embeds,
|
||||
attention_mask and seq_length.
|
||||
"""
|
||||
|
||||
all_model_classes = (TFViTModel, TFViTForImageClassification) if is_tf_available() else ()
|
||||
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFViTModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=ViTConfig, has_text_modality=False, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
# ViT does not use inputs_embeds
|
||||
pass
|
||||
|
||||
def test_graph_mode_with_inputs_embeds(self):
|
||||
# ViT does not use inputs_embeds
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Layer))
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.call)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
# overwrite from common since `encoder_seq_length` and `encoder_key_length` are calculated
|
||||
# in a different way than in text models.
|
||||
@tooslow
|
||||
def test_saved_model_creation_extended(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
|
||||
if hasattr(config, "use_cache"):
|
||||
config.use_cache = True
|
||||
|
||||
# in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
image_size = to_2tuple(self.model_tester.image_size)
|
||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
seq_len = num_patches + 1
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
num_out = len(model(class_inputs_dict))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, saved_model=True)
|
||||
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
|
||||
model = tf.keras.models.load_model(saved_model_dir)
|
||||
outputs = model(class_inputs_dict)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
output_hidden_states = outputs["encoder_hidden_states"]
|
||||
output_attentions = outputs["encoder_attentions"]
|
||||
else:
|
||||
output_hidden_states = outputs["hidden_states"]
|
||||
output_attentions = outputs["attentions"]
|
||||
|
||||
self.assertEqual(len(outputs), num_out)
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
|
||||
self.assertEqual(len(output_hidden_states), expected_num_layers)
|
||||
self.assertListEqual(
|
||||
list(output_hidden_states[0].shape[-2:]),
|
||||
[seq_len, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(output_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
# in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
image_size = to_2tuple(self.model_tester.image_size)
|
||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
seq_len = num_patches + 1
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
|
||||
|
||||
if hasattr(self.model_tester, "num_hidden_states_types"):
|
||||
added_hidden_states = self.model_tester.num_hidden_states_types
|
||||
elif self.is_encoder_decoder:
|
||||
added_hidden_states = 2
|
||||
else:
|
||||
added_hidden_states = 1
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
# ViT has a different seq_length
|
||||
image_size = to_2tuple(self.model_tester.image_size)
|
||||
patch_size = to_2tuple(self.model_tester.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
seq_length = num_patches + 1
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
def test_for_image_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
|
||||
model = TFViTModel.from_pretrained("google/vit-base-patch16-224", from_pt=True)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
return image
|
||||
|
||||
|
||||
@require_vision
|
||||
class TFViTModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_feature_extractor(self):
|
||||
return ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224") if is_vision_available() else None
|
||||
|
||||
@slow
|
||||
def test_inference_image_classification_head(self):
|
||||
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224", from_pt=True)
|
||||
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image = prepare_img()
|
||||
inputs = feature_extractor(images=image, return_tensors="tf")
|
||||
|
||||
# forward pass
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = tf.TensorShape((1, 1000))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = tf.constant([-0.2744, 0.8215, -0.0836])
|
||||
|
||||
tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4)
|
||||
Reference in New Issue
Block a user