From fc1d97f29d7b98e82ae17fc5ac49229e2859bcca Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Tue, 30 Nov 2021 22:21:48 +0530 Subject: [PATCH] VisionTextDualEncoder (#13511) * init vision_text_dual_encoder * fix merge * remove extra heads * fix tests * remove VISION_TEXT_DUAL_ENCODER_PRETRAINED_CONFIG_ARCHIVE_MAP * remove archive map * fix imports * fix more imports * fix init * delete tokenizers * fix imports * clean * support clip's vision model * handle None config * begin tests * more test and few fixes * warn about newly init weights * more tests * add loss to model * remove extra classes from doc * add processor * doc and small fixes * add start docstr * update flax model * flax tests * more flax tests * doc * quality * doc and quality * fix doc * doc * remove comments * update warning * quality * fix docs * Apply suggestions from code review Co-authored-by: Patrick von Platen * replace asserts, fix imports * update imports * fix import * address some review comments * fix check * reduce tolerance * fix test * add flax integration test * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * address Sylvain's comments * fix style * add pt_flax_equivalence test in PT tests * add pt integration test * update test * use pre-trained checkpoint in examples Co-authored-by: Patrick von Platen Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/index.rst | 3 + .../model_doc/vision_text_dual_encoder.rst | 56 ++ src/transformers/__init__.py | 7 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 1 + .../models/auto/modeling_flax_auto.py | 1 + .../models/auto/processing_auto.py | 1 + .../vision_text_dual_encoder/__init__.py | 52 ++ .../configuration_vision_text_dual_encoder.py | 129 ++++ .../modeling_flax_vision_text_dual_encoder.py | 568 ++++++++++++++++++ 
.../modeling_vision_text_dual_encoder.py | 519 ++++++++++++++++ .../processing_vision_text_dual_encoder.py | 185 ++++++ src/transformers/utils/dummy_flax_objects.py | 12 + src/transformers/utils/dummy_pt_objects.py | 12 + ..._modeling_flax_vision_text_dual_encoder.py | 395 ++++++++++++ .../test_modeling_vision_text_dual_encoder.py | 527 ++++++++++++++++ ...test_processor_vision_text_dual_encoder.py | 170 ++++++ utils/check_repo.py | 2 + 19 files changed, 2643 insertions(+) create mode 100644 docs/source/model_doc/vision_text_dual_encoder.rst create mode 100644 src/transformers/models/vision_text_dual_encoder/__init__.py create mode 100644 src/transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py create mode 100644 src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py create mode 100755 src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py create mode 100644 src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py create mode 100644 tests/test_modeling_flax_vision_text_dual_encoder.py create mode 100644 tests/test_modeling_vision_text_dual_encoder.py create mode 100644 tests/test_processor_vision_text_dual_encoder.py diff --git a/docs/source/index.rst b/docs/source/index.rst index 0c0ed1a0ea..91ffeee47e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -511,6 +511,8 @@ Flax), PyTorch, and/or TensorFlow. 
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | Vision Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| VisionTextDualEncoder | ❌ | ❌ | ✅ | ❌ | ✅ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | ViT | ❌ | ❌ | ✅ | ✅ | ✅ | @@ -686,6 +688,7 @@ Flax), PyTorch, and/or TensorFlow. model_doc/unispeech model_doc/unispeech_sat model_doc/visionencoderdecoder + model_doc/vision_text_dual_encoder model_doc/vit model_doc/visual_bert model_doc/wav2vec2 diff --git a/docs/source/model_doc/vision_text_dual_encoder.rst b/docs/source/model_doc/vision_text_dual_encoder.rst new file mode 100644 index 0000000000..2544a5388a --- /dev/null +++ b/docs/source/model_doc/vision_text_dual_encoder.rst @@ -0,0 +1,56 @@ +.. + Copyright 2021 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. 
+ +VisionTextDualEncoder +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The :class:`~transformers.VisionTextDualEncoderModel` can be used to initialize a vision-text dual encoder model with +any pretrained vision autoencoding model as the vision encoder (*e.g.* :doc:`ViT <vit>`, :doc:`BEiT <beit>`, :doc:`DeiT +<deit>`) and any pretrained text autoencoding model as the text encoder (*e.g.* :doc:`RoBERTa <roberta>`, :doc:`BERT +<bert>`). Two projection layers are added on top of both the vision and text encoder to project the output embeddings +to a shared latent space. The projection layers are randomly initialized so the model should be fine-tuned on a +downstream task. This model can be used to align the vision-text embeddings using CLIP-like contrastive image-text +training and then can be used for zero-shot vision tasks such as image classification or retrieval. + +In `LiT: Zero-Shot Transfer with Locked-image Text Tuning <https://arxiv.org/abs/2111.07991>`__ it is shown how +leveraging pre-trained (locked/frozen) image and text models for contrastive learning yields significant improvement on +new zero-shot vision tasks such as image classification or retrieval. + +VisionTextDualEncoderConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.VisionTextDualEncoderConfig + :members: + + +VisionTextDualEncoderProcessor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.VisionTextDualEncoderProcessor + :members: + + +VisionTextDualEncoderModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.VisionTextDualEncoderModel + :members: forward + + +FlaxVisionTextDualEncoderModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.FlaxVisionTextDualEncoderModel + :members: __call__ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a066edab98..d46aa31234 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -297,6 +297,7 @@ _import_structure = { "UniSpeechSatConfig", ], "models.vision_encoder_decoder": ["VisionEncoderDecoderConfig"], + "models.vision_text_dual_encoder": ["VisionTextDualEncoderConfig", "VisionTextDualEncoderProcessor"], "models.visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"], "models.vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"], "models.wav2vec2": [ @@ -1307,6 +1308,7 @@ if is_torch_available(): ] ) _import_structure["models.vision_encoder_decoder"].extend(["VisionEncoderDecoderModel"]) + _import_structure["models.vision_text_dual_encoder"].extend(["VisionTextDualEncoderModel"]) _import_structure["models.visual_bert"].extend( [ "VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1908,6 +1910,7 @@ if is_flax_available(): ) # Flax models structure + _import_structure["models.bart"].extend( [ "FlaxBartForConditionalGeneration", @@ -2028,6 +2031,7 @@ if is_flax_available(): ) _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"]) _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel") + _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"]) _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"]) _import_structure["models.wav2vec2"].extend( ["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", 
"FlaxWav2Vec2PreTrainedModel"] @@ -2268,6 +2272,7 @@ if TYPE_CHECKING: from .models.unispeech import UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechConfig from .models.unispeech_sat import UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechSatConfig from .models.vision_encoder_decoder import VisionEncoderDecoderConfig + from .models.vision_text_dual_encoder import VisionTextDualEncoderConfig, VisionTextDualEncoderProcessor from .models.visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig from .models.vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig from .models.wav2vec2 import ( @@ -3111,6 +3116,7 @@ if TYPE_CHECKING: UniSpeechSatPreTrainedModel, ) from .models.vision_encoder_decoder import VisionEncoderDecoderModel + from .models.vision_text_dual_encoder import VisionTextDualEncoderModel from .models.visual_bert import ( VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST, VisualBertForMultipleChoice, @@ -3706,6 +3712,7 @@ if TYPE_CHECKING: ) from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel + from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel from .models.wav2vec2 import ( FlaxWav2Vec2ForCTC, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 451077e1b8..b180e9401f 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -101,6 +101,7 @@ from . 
import ( unispeech, unispeech_sat, vision_encoder_decoder, + vision_text_dual_encoder, visual_bert, vit, wav2vec2, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 2070a63994..8d1293042e 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -36,6 +36,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("trocr", "TrOCRConfig"), ("fnet", "FNetConfig"), ("segformer", "SegformerConfig"), + ("vision-text-dual-encoder", "VisionTextDualEncoderConfig"), ("gptj", "GPTJConfig"), ("layoutlmv2", "LayoutLMv2Config"), ("beit", "BeitConfig"), @@ -192,6 +193,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("trocr", "TrOCR"), ("fnet", "FNet"), ("segformer", "SegFormer"), + ("vision-text-dual-encoder", "VisionTextDualEncoder"), ("gptj", "GPT-J"), ("beit", "BEiT"), ("rembert", "RemBERT"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 403c59c67d..e262f70e64 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -32,6 +32,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("qdqbert", "QDQBertModel"), ("fnet", "FNetModel"), ("segformer", "SegformerModel"), + ("vision-text-dual-encoder", "VisionTextDualEncoderModel"), ("gptj", "GPTJModel"), ("layoutlmv2", "LayoutLMv2Model"), ("beit", "BeitModel"), diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 4df31b3ada..fd9b43205c 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -29,6 +29,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping ("pegasus", "FlaxPegasusModel"), + ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"), ("distilbert", "FlaxDistilBertModel"), ("albert", "FlaxAlbertModel"), ("roberta", "FlaxRobertaModel"), diff --git 
a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index c805d994a2..349a885593 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -41,6 +41,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( ("speech_to_text_2", "Speech2Text2Processor"), ("trocr", "TrOCRProcessor"), ("wav2vec2", "Wav2Vec2Processor"), + ("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"), ] ) diff --git a/src/transformers/models/vision_text_dual_encoder/__init__.py b/src/transformers/models/vision_text_dual_encoder/__init__.py new file mode 100644 index 0000000000..fcc856c22f --- /dev/null +++ b/src/transformers/models/vision_text_dual_encoder/__init__.py @@ -0,0 +1,52 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +# rely on isort to merge the imports +from ...file_utils import _LazyModule, is_flax_available, is_torch_available + + +_import_structure = { + "configuration_vision_text_dual_encoder": ["VisionTextDualEncoderConfig"], + "processing_vision_text_dual_encoder": ["VisionTextDualEncoderProcessor"], +} + + +if is_torch_available(): + _import_structure["modeling_vision_text_dual_encoder"] = ["VisionTextDualEncoderModel"] + + +if is_flax_available(): + _import_structure["modeling_flax_vision_text_dual_encoder"] = ["FlaxVisionTextDualEncoderModel"] + + +if TYPE_CHECKING: + from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig + from .processing_vision_text_dual_encoder import VisionTextDualEncoderProcessor + + if is_torch_available(): + from .modeling_vision_text_dual_encoder import VisionTextDualEncoderModel + + if is_flax_available(): + from .modeling_flax_vision_text_dual_encoder import FlaxVisionTextDualEncoderModel + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py new file mode 100644 index 0000000000..b2223e41f5 --- /dev/null +++ b/src/transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py @@ -0,0 +1,129 @@ +# coding=utf-8 +# Copyright The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" VisionTextDualEncoder model configuration """ + +import copy + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto.configuration_auto import AutoConfig +from ..clip.configuration_clip import CLIPVisionConfig + + +logger = logging.get_logger(__name__) + + +class VisionTextDualEncoderConfig(PretrainedConfig): + r""" + :class:`~transformers.VisionTextDualEncoderConfig` is the configuration class to store the configuration of a + :class:`~transformers.VisionTextDualEncoderModel`. It is used to instantiate + :class:`~transformers.VisionTextDualEncoderModel` model according to the specified arguments, defining the text + model and vision model configs. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + Args: + text_config_dict (:obj:`dict`): + Dictionary of configuration options that defines text model config. + vision_config_dict (:obj:`dict`): + Dictionary of configuration options that defines vision model config. + projection_dim (:obj:`int`, `optional`, defaults to 512): + Dimensionality of text and vision projection layers. + logit_scale_init_value (:obj:`float`, `optional`, defaults to 2.6592): + The initial value of the `logit_scale` parameter. Default is used as per the original CLIP implementation. + kwargs (`optional`): + Dictionary of keyword arguments. 
+ + Examples:: + + >>> from transformers import ViTConfig, BertConfig, VisionTextDualEncoderConfig, VisionTextDualEncoderModel + + >>> # Initializing a BERT and ViT configuration + >>> config_vision = ViTConfig() + >>> config_text = BertConfig() + + >>> config = VisionTextDualEncoderConfig.from_vision_text_configs(config_vision, config_text, projection_dim=512) + + >>> # Initializing a BERT and ViT model + >>> model = VisionTextDualEncoderModel(config=config) + + >>> # Accessing the model configuration + >>> config_vision = model.config.vision_config + >>> config_text = model.config.text_config + + >>> # Saving the model, including its configuration + >>> model.save_pretrained('my-model') + + >>> # loading model and config from pretrained folder + >>> vision_text_config = VisionTextDualEncoderConfig.from_pretrained('vit-bert') + >>> model = VisionTextDualEncoderModel.from_pretrained('vit-bert', config=vision_text_config) + """ + + model_type = "vision-text-dual-encoder" + is_composition = True + + def __init__(self, projection_dim=512, logit_scale_init_value=2.6592, **kwargs): + super().__init__(**kwargs) + + if "vision_config" not in kwargs: + raise ValueError("`vision_config` can not be `None`.") + + if "text_config" not in kwargs: + raise ValueError("`text_config` can not be `None`.") + + vision_config = kwargs.pop("vision_config") + text_config = kwargs.pop("text_config") + + vision_model_type = vision_config.pop("model_type") + text_model_type = text_config.pop("model_type") + + if vision_model_type == "clip": + self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config + elif vision_model_type == "clip_vision_model": + self.vision_config = CLIPVisionConfig(**vision_config) + else: + self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config) + + self.text_config = AutoConfig.for_model(text_model_type, **text_config) + + self.projection_dim = projection_dim + self.logit_scale_init_value = 
logit_scale_init_value + + @classmethod + def from_vision_text_configs(cls, vision_config: PretrainedConfig, text_config: PretrainedConfig, **kwargs): + r""" + Instantiate a :class:`VisionTextDualEncoderConfig` (or a derived class) from text model configuration and + vision model configuration. + + Returns: + :class:`VisionTextDualEncoderConfig`: An instance of a configuration object + """ + + return cls(vision_config=vision_config.to_dict(), text_config=text_config.to_dict(), **kwargs) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default + :meth:`~transformers.PretrainedConfig.to_dict`. + + Returns: + :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["vision_config"] = self.vision_config.to_dict() + output["text_config"] = self.text_config.to_dict() + output["model_type"] = self.__class__.model_type + return output diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py new file mode 100644 index 0000000000..5c03904a73 --- /dev/null +++ b/src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py @@ -0,0 +1,568 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" Flax VisionTextDualEncoder model. """ + + +from typing import Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict + +from ...file_utils import add_start_docstrings +from ...modeling_flax_utils import FlaxPreTrainedModel, append_replace_return_docstrings, overwrite_call_docstring +from ...utils import logging +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_flax_auto import FLAX_MODEL_MAPPING, FlaxAutoModel +from ..clip.modeling_flax_clip import FlaxCLIPOutput, FlaxCLIPVisionModel +from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "VisionTextDualEncoderConfig" + +VISION_TEXT_DUAL_ENCODER_START_DOCSTRING = r""" + This class can be used to initialize a vision-text dual encoder model with any pretrained vision autoencoding model + as the vision encoder and any pretrained text model as the text encoder. The vision and text encoders are loaded + via the :meth:`~transformers.FlaxAutoModel.from_pretrained` method. The projection layers are automatically added + to the model and should be fine-tuned on a downstream task, like contrastive image-text modeling. + + In `LiT: Zero-Shot Transfer with Locked-image Text Tuning <https://arxiv.org/abs/2111.07991>`__ it is shown how + leveraging pre-trained (locked/frozen) image and text models for contrastive learning yields significant improvement + on new zero-shot vision tasks such as image classification or retrieval. + + After such a Vision-Text-Dual-Encoder model has been trained/fine-tuned, it can be saved/loaded just like any other + models (see the examples for more information). + + This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its models (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) 
+ + This model is also a Flax Linen `flax.linen.Module + `__ subclass. Use it as a regular Flax linen Module + and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - `Just-In-Time (JIT) compilation `__ + - `Automatic Differentiation `__ + - `Vectorization `__ + - `Parallelization `__ + + Parameters: + config (:class:`~transformers.VisionTextDualEncoderConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the + model weights. + dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`): + The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on + GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given ``dtype``. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see + :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`. +""" + + +VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? 
<../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + a feature extractor (e.g. if you use ViT as the encoder, you should use + :class:`~transformers.ViTFeatureExtractor`). See :meth:`transformers.ViTFeatureExtractor.__call__` for + details. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 
+""" + + +class FlaxVisionTextDualEncoderModule(nn.Module): + config: VisionTextDualEncoderConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + vision_config = self.config.vision_config + text_config = self.config.text_config + + self.vision_embed_dim = vision_config.hidden_size + self.text_embed_dim = text_config.hidden_size + self.projection_dim = self.config.projection_dim + + vision_module = FLAX_MODEL_MAPPING.get(self.config.vision_config.__class__, FlaxCLIPVisionModel).module_class + text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class + + self.vision_model = vision_module(vision_config, dtype=self.dtype) + self.text_model = text_module(text_config, dtype=self.dtype) + + self.visual_projection = nn.Dense( + self.projection_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(0.02), + use_bias=False, + ) + self.text_projection = nn.Dense( + self.projection_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(0.02), + use_bias=False, + ) + + self.logit_scale = self.param( + "logit_scale", lambda _, shape: jnp.ones(shape) * self.config.logit_scale_init_value, [] + ) + + def __call__( + self, + input_ids=None, + pixel_values=None, + attention_mask=None, + position_ids=None, + token_type_ids=None, + deterministic: bool = True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + 
image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True) + text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True) + + # cosine similarity as logits + logit_scale = jnp.exp(self.logit_scale) + logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale + logits_per_image = logits_per_text.T + + if not return_dict: + return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + + return FlaxCLIPOutput( + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +@add_start_docstrings(VISION_TEXT_DUAL_ENCODER_START_DOCSTRING) +class FlaxVisionTextDualEncoderModel(FlaxPreTrainedModel): + config_class = VisionTextDualEncoderConfig + module_class = FlaxVisionTextDualEncoderModule + + def __init__( + self, + config: VisionTextDualEncoderConfig, + input_shape: Optional[Tuple] = None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + **kwargs + ): + if input_shape is None: + input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3)) + + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + # init input tensor + input_ids = jnp.zeros(input_shape[0], dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0]) + token_type_ids = jnp.ones_like(input_ids) + attention_mask = jnp.ones_like(input_ids) + + pixel_values = jax.random.normal(rng, input_shape[1]) + + 
params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"] + + def __call__( + self, + input_ids, + pixel_values, + attention_mask=None, + position_ids=None, + token_type_ids=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(pixel_values, dtype=jnp.float32), + jnp.array(attention_mask, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + jnp.array(token_type_ids, dtype="i4"), + not train, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + ) + + def get_text_features( + self, + input_ids, + attention_mask=None, + position_ids=None, + token_type_ids=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train=False, + ): + r""" + Args: + input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`): + Indices of 
input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + + Returns: + text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of text model. + """ + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic): + text_outputs = module.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + token_type_ids=token_type_ids, + deterministic=deterministic, + ) + pooled_output = text_outputs[1] + text_features = module.text_projection(pooled_output) + return text_features + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + jnp.array(token_type_ids, dtype="i4"), + not train, + method=_get_features, + rngs=rngs, + ) + + def get_image_features( + self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False + ): + r""" + Args: + pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. 
Pixel values can be obtained + using :class:`~transformers.ImageFeatureExtractionMixin`. See + :meth:`transformers.ImageFeatureExtractionMixin.__call__` for details. + + Returns: + image_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim`): The image embeddings obtained + by applying the projection layer to the pooled output of vision model. + """ + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _get_features(module, pixel_values, deterministic): + vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic) + pooled_output = vision_outputs[1] # pooled_output + image_features = module.visual_projection(pooled_output) + return image_features + + return self.module.apply( + {"params": params or self.params}, + jnp.array(pixel_values, dtype=jnp.float32), + not train, + method=_get_features, + rngs=rngs, + ) + + @classmethod + def from_vision_text_pretrained( + cls, + vision_model_name_or_path: str = None, + text_model_name_or_path: str = None, + *model_args, + **kwargs, + ) -> FlaxPreTrainedModel: + """ + Params: + vision_model_name_or_path (:obj: `str`, `optional`, defaults to `None`): + Information necessary to initiate the vision model. Can be either: + + - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing model weights saved using + :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In this case, ``from_pt`` + should be set to :obj:`True` and a configuration object should be provided as ``config`` + argument. 
This loading path is slower than converting the PyTorch checkpoint in a Flax model + using the provided conversion scripts and loading the Flax model afterwards. + + text_model_name_or_path (:obj: `str`, `optional`): + Information necessary to initiate the text model. Can be either: + + - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing model weights saved using + :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In this case, ``from_pt`` + should be set to :obj:`True` and a configuration object should be provided as ``config`` + argument. This loading path is slower than converting the PyTorch checkpoint in a Flax model + using the provided conversion scripts and loading the Flax model afterwards. + + model_args (remaining positional arguments, `optional`): + All remaning positional arguments will be passed to the underlying model's ``__init__`` method. + + kwargs (remaining dictionary of keyword arguments, `optional`): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + :obj:`output_attentions=True`). + + - To update the text configuration, use the prefix `text_` for each configuration parameter. + - To update the vision configuration, use the prefix `vision_` for each configuration parameter. + - To update the parent model configuration, do not use a prefix for each configuration parameter. + + Behaves differently depending on whether a :obj:`config` is provided or automatically loaded. + + Example:: + + >>> from transformers import FlaxVisionTextDualEncoderModel + >>> # initialize a model from pretrained ViT and BERT models. 
Note that the projection layers will be randomly initialized. + >>> model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained('bert-base-uncased', 'google/vit-base-patch16-224') + >>> # saving model after fine-tuning + >>> model.save_pretrained("./vit-bert") + >>> # load fine-tuned model + >>> model = FlaxVisionTextDualEncoderModel.from_pretrained("./vit-bert") + """ + + kwargs_vision = { + argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_") + } + + kwargs_text = { + argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_") + } + + # remove text, vision kwargs from kwargs + for key in kwargs_vision.keys(): + del kwargs["vision_" + key] + for key in kwargs_text.keys(): + del kwargs["text_" + key] + + # Load and initialize the text and vision model + vision_model = kwargs_vision.pop("model", None) + if vision_model is None: + if vision_model_name_or_path is None: + raise ValueError( + "If `vision_model` is not defined as an argument, a `vision_model_name_or_path` has to be defined" + ) + + if "config" not in kwargs_vision: + vision_config = AutoConfig.from_pretrained(vision_model_name_or_path) + + if vision_config.model_type == "clip": + kwargs_vision["config"] = vision_config.vision_config + vision_model = FlaxCLIPVisionModel.from_pretrained( + vision_model_name_or_path, *model_args, **kwargs_vision + ) + else: + kwargs_vision["config"] = vision_config + vision_model = FlaxAutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision) + + text_model = kwargs_text.pop("model", None) + if text_model is None: + if text_model_name_or_path is None: + raise ValueError( + "If `text_model` is not defined as an argument, a `text_model_name_or_path` has to be defined" + ) + + if "config" not in kwargs_text: + text_config = AutoConfig.from_pretrained(text_model_name_or_path) + kwargs_text["config"] = text_config + + text_model = 
FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text) + + # instantiate config with corresponding kwargs + dtype = kwargs.pop("dtype", jnp.float32) + config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_model.config, text_model.config, **kwargs) + + # init model + model = cls(config, *model_args, dtype=dtype, **kwargs) + + model.params["vision_model"] = vision_model.params + model.params["text_model"] = text_model.params + + # the projection layers are always newly initialized when loading the model + # using pre-trained vision and text model. + logger.warning( + "The projection layer and logit scale weights `[('visual_projection', 'kernel'), ('text_projection', 'kernel'), ('logit_scale',)]` " + "are newly initialized. You should probably TRAIN this model on a down-stream task " + "to be able to use it for predictions and inference." + ) + + return model + + +VISION_TEXT_DUAL_ENCODER_MODEL_DOCSTRING = r""" + Returns: + + Examples:: + + >>> from PIL import Image + >>> import requests + >>> import jax + >>> from transformers import FlaxVisionTextDualEncoderModel, VisionTextDualEncoderProcessor, ViTFeatureExtractor, BertTokenizer + + >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224") + >>> processor = VisionTextDualEncoderProcessor(feature_extractor, tokenizer) + >>> model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained("google/vit-base-patch16-224", "bert-base-uncased") + + >>> # contrastive training + >>> urls = ["http://images.cocodataset.org/val2017/000000039769.jpg", "https://farm3.staticflickr.com/2674/5850229113_4fe05d5265_z.jpg] + >>> images = [Image.open(requests.get(url, stream=True).raw) for url in urls] + >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=images, return_tensors="np", padding=True) + >>> outputs = model(input_ids=inputs.input_ids, 
attention_mask=inputs.attention_mask, pixel_values=inputs.pixel_values) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + + >>> # save and load from pretrained + >>> model.save_pretrained("vit-bert") + >>> model = FlaxVisionTextDualEncoderModel.from_pretrained("vit-bert") + + >>> # inference + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = jax.nn.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities + +""" + +overwrite_call_docstring( + FlaxVisionTextDualEncoderModel, + VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING + VISION_TEXT_DUAL_ENCODER_MODEL_DOCSTRING, +) +append_replace_return_docstrings( + FlaxVisionTextDualEncoderModel, output_type=FlaxCLIPOutput, config_class=_CONFIG_FOR_DOC +) diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py new file mode 100755 index 0000000000..43fb97403b --- /dev/null +++ b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py @@ -0,0 +1,519 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch VisionTextDualEncoder model.
""" + + +from typing import Optional + +import torch +from torch import nn + +from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_auto import AutoModel +from ..clip.modeling_clip import CLIPOutput, CLIPVisionConfig, CLIPVisionModel +from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "VisionTextDualEncoderConfig" + +VISION_TEXT_DUAL_ENCODER_START_DOCSTRING = r""" + This class can be used to initialize a vision-text dual encoder model with any pretrained vision autoencoding model + as the vision encoder and any pretrained text model as the text encoder. The vision and text encoders are loaded + via the :meth:`~transformers.AutoModel.from_pretrained` method. The projection layers are automatically added to + the model and should be fine-tuned on a downstream task, like contrastive image-text modeling. + + In `LiT: Zero-Shot Transfer with Locked-image Text Tuning `__ it is shown how + leveraging pre-trained (locked/frozen) image and text model for contrastive learning yields significant improvment + on new zero-shot vision tasks such as image classification or retrieval. + + After such a Vision-Text-Dual-Encoder model has been trained/fine-tuned, it can be saved/loaded just like any other + models (see the examples for more information). + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. 
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.VisionTextDualEncoderConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + + +VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers.
See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +VISION_TEXT_DUAL_ENCODER_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + :class:`~transformers.CLIPFeatureExtractor`. See :meth:`transformers.CLIPFeatureExtractor.__call__` for + details. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using :class:`~transformers.CLIPTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? 
<../glossary.html#attention-mask>`__ + position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + a feature extractor (e.g. if you use ViT as the encoder, you should use + :class:`~transformers.ViTFeatureExtractor`). See :meth:`transformers.ViTFeatureExtractor.__call__` for + details. + return_loss (:obj:`bool`, `optional`): + Whether or not to return the contrastive loss. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 
+""" + + +# Copied from transformers.models.clip.modeling_clip.contrastive_loss +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +# Copied from transformers.models.clip.modeling_clip.clip_loss +def clip_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.T) + return (caption_loss + image_loss) / 2.0 + + +@add_start_docstrings(VISION_TEXT_DUAL_ENCODER_START_DOCSTRING) +class VisionTextDualEncoderModel(PreTrainedModel): + config_class = VisionTextDualEncoderConfig + base_model_prefix = "vision_text_dual_encoder" + + def __init__( + self, + config: Optional[VisionTextDualEncoderConfig] = None, + vision_model: Optional[PreTrainedModel] = None, + text_model: Optional[PreTrainedModel] = None, + ): + + if config is None and (vision_model is None or text_model is None): + raise ValueError("Either a configuration or an vision and a text model has to be provided") + + if config is None: + config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_model.config, text_model.config) + else: + if not isinstance(config, self.config_class): + raise ValueError(f"config: {config} has to be of type {self.config_class}") + + # initialize with config + super().__init__(config) + + if vision_model is None: + if isinstance(config.vision_config, CLIPVisionConfig): + vision_model = CLIPVisionModel(config.vision_config) + else: + vision_model = AutoModel.from_config(config.vision_config) + + if text_model is None: + text_model = AutoModel.from_config(config.text_config) + + self.vision_model = vision_model + self.text_model = text_model + + # make sure that the individual model's config refers to the shared config + # so that the updates to the config will be synced + self.vision_model.config = self.config.vision_config + self.text_model.config = self.config.text_config + + 
self.vision_embed_dim = config.vision_config.hidden_size + self.text_embed_dim = config.text_config.hidden_size + self.projection_dim = config.projection_dim + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value) + + @add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + text_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, output_dim`): The text embeddings + obtained by applying the projection layer to the pooled output of :class:`~transformers.CLIPTextModel`. + + Examples:: + + >>> from transformers import VisionTextDualEncoderModel, AutoTokenizer + + >>> model = VisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian") + >>> tokenizer = AutoTokenizer.from_pretrained("clip-italian/clip-italian") + + >>> inputs = tokenizer(["una foto di un gatto", "una foto di un cane"], padding=True, return_tensors="pt") + >>> text_features = model.get_text_features(**inputs) + """ + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features + + @add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + image_features (:obj:`torch.FloatTensor` of shape 
:obj:`(batch_size, output_dim`): The image embeddings + obtained by applying the projection layer to the pooled output of :class:`~transformers.CLIPVisionModel`. + + Examples:: + + >>> from PIL import Image + >>> import requests + >>> from transformers import VisionTextDualEncoderModel, AutoFeatureExtractor + + >>> model = VisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + + >>> image_features = model.get_image_features(**inputs) + """ + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.visual_projection(pooled_output) + + return image_features + + @add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CLIPOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + pixel_values=None, + attention_mask=None, + position_ids=None, + return_loss=None, + token_type_ids=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Examples:: + + >>> from PIL import Image + >>> import requests + >>> from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor, ViTFeatureExtractor, BertTokenizer + + >>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224") + >>> processor = VisionTextDualEncoderProcessor(feature_extractor, tokenizer) + >>> model = 
VisionTextDualEncoderModel.from_vision_text_pretrained("google/vit-base-patch16-224", "bert-base-uncased") + + >>> # contrastive training + >>> urls = ["http://images.cocodataset.org/val2017/000000039769.jpg", "https://farm3.staticflickr.com/2674/5850229113_4fe05d5265_z.jpg] + >>> images = [Image.open(requests.get(url, stream=True).raw) for url in urls] + >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=images, return_tensors="pt", padding=True) + >>> outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, pixel_values=inputs.pixel_values, return_loss=True) + >>> loss, logits_per_image = outputs.loss, outputs.logits_per_imag # this is the image-text similarity score + + >>> # save and load from pretrained + >>> model.save_pretrained("vit-bert") + >>> model = VisionTextDualEncoderModel.from_pretrained("vit-bert") + + >>> # inference + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] # pooler_output + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[1] # pooler_output + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True) + text_embeds = 
text_embeds / text_embeds.norm(dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.T + + loss = None + if return_loss: + loss = clip_loss(logits_per_text) + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return CLIPOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + # At the moment fast initialization is not supported + # for composite models + kwargs["_fast_init"] = False + return super().from_pretrained(*args, **kwargs) + + @classmethod + def from_vision_text_pretrained( + cls, + vision_model_name_or_path: str = None, + text_model_name_or_path: str = None, + *model_args, + **kwargs, + ) -> PreTrainedModel: + """ + Params: + vision_model_name_or_path (:obj: `str`, `optional`, defaults to `None`): + Information necessary to initiate the vision model. Can be either: + + - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing model weights saved using + :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In this case, ``from_pt`` + should be set to :obj:`True` and a configuration object should be provided as ``config`` + argument. 
This loading path is slower than converting the PyTorch checkpoint in a Flax model + using the provided conversion scripts and loading the Flax model afterwards. + + text_model_name_or_path (:obj: `str`, `optional`): + Information necessary to initiate the text model. Can be either: + + - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing model weights saved using + :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In this case, ``from_pt`` + should be set to :obj:`True` and a configuration object should be provided as ``config`` + argument. This loading path is slower than converting the PyTorch checkpoint in a Flax model + using the provided conversion scripts and loading the Flax model afterwards. + + model_args (remaining positional arguments, `optional`): + All remaning positional arguments will be passed to the underlying model's ``__init__`` method. + + kwargs (remaining dictionary of keyword arguments, `optional`): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + :obj:`output_attentions=True`). + + - To update the text configuration, use the prefix `text_` for each configuration parameter. + - To update the vision configuration, use the prefix `vision_` for each configuration parameter. + - To update the parent model configuration, do not use a prefix for each configuration parameter. + + Behaves differently depending on whether a :obj:`config` is provided or automatically loaded. + + Example:: + + >>> from transformers import VisionTextDualEncoderModel + >>> # initialize a model from pretrained ViT and BERT models. 
Note that the projection layers will be randomly initialized. + >>> model = VisionTextDualEncoderModel.from_vision_text_pretrained('bert-base-uncased', 'google/vit-base-patch16-224') + >>> # saving model after fine-tuning + >>> model.save_pretrained("./vit-bert") + >>> # load fine-tuned model + >>> model = VisionTextDualEncoderModel.from_pretrained("./vit-bert") + """ + kwargs_vision = { + argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_") + } + + kwargs_text = { + argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_") + } + + # remove vision, text kwargs from kwargs + for key in kwargs_vision.keys(): + del kwargs["vision_" + key] + for key in kwargs_text.keys(): + del kwargs["text_" + key] + + # Load and initialize the vision and text model + vision_model = kwargs_vision.pop("model", None) + if vision_model is None: + if vision_model_name_or_path is None: + raise ValueError( + "If `vision_model` is not defined as an argument, a `vision_model_name_or_path` has to be defined" + ) + + if "config" not in kwargs_vision: + vision_config = AutoConfig.from_pretrained(vision_model_name_or_path) + + if vision_config.model_type == "clip": + kwargs_vision["config"] = vision_config.vision_config + vision_model = CLIPVisionModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision) + # TODO: Should we use the pre-trained projection as well ? 
+ else: + kwargs_vision["config"] = vision_config + vision_model = AutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision) + + text_model = kwargs_text.pop("model", None) + if text_model is None: + if text_model_name_or_path is None: + raise ValueError( + "If `text_model` is not defined as an argument, a `text_model_name_or_path` has to be defined" + ) + + if "config" not in kwargs_text: + text_config = AutoConfig.from_pretrained(text_model_name_or_path) + kwargs_text["config"] = text_config + + text_model = AutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text) + + # instantiate config with corresponding kwargs + config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_model.config, text_model.config, **kwargs) + + # init model + model = cls(config=config, vision_model=vision_model, text_model=text_model) + + # the projection layers are always newly initialized when loading the model + # using pre-trained vision and text model. + logger.warning( + "The projection layer and logit scale weights `['visual_projection.weight', 'text_projection.weight', 'logit_scale']` " + "are newly initialized. You should probably TRAIN this model on a down-stream task " + "to be able to use it for predictions and inference." + ) + + return model diff --git a/src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py new file mode 100644 index 0000000000..fb32320780 --- /dev/null +++ b/src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py @@ -0,0 +1,185 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class VisionTextDualEncoderProcessor:
    r"""
    Constructs a VisionTextDualEncoder processor which wraps a vision feature extractor and a tokenizer into a single
    processor.

    The wrapped objects are kept on the instance as :obj:`feature_extractor` and :obj:`tokenizer`. See
    :meth:`~transformers.VisionTextDualEncoderProcessor.__call__` and
    :meth:`~transformers.VisionTextDualEncoderProcessor.decode` for more information.

    Args:
        feature_extractor (:class:`~transformers.feature_extraction_utils.FeatureExtractionMixin`):
            The feature extractor is a required input.
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer is a required input.

    Raises:
        :obj:`ValueError`: If :obj:`feature_extractor` or :obj:`tokenizer` is not an instance of the types above.
    """

    def __init__(
        self, feature_extractor: FeatureExtractionMixin, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
    ):
        if not isinstance(feature_extractor, FeatureExtractionMixin):
            # Interpolate the class itself: `FeatureExtractionMixin.__class__` is the metaclass and would
            # render as "<class 'type'>", which tells the caller nothing about the expected type.
            raise ValueError(
                f"`feature_extractor` has to be of type {FeatureExtractionMixin}, but is {type(feature_extractor)}"
            )
        if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
            raise ValueError(
                f"`tokenizer` has to be of type `PreTrainedTokenizer` or `PreTrainedTokenizerFast`, but is {type(tokenizer)}"
            )

        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.current_processor = self.feature_extractor

    def save_pretrained(self, save_directory):
        """
        Save the wrapped feature extractor and the wrapped tokenizer to the directory ``save_directory``, so that the
        processor can be re-loaded using the :func:`~transformers.VisionTextDualEncoderProcessor.from_pretrained`
        class method.

        .. note::

            This method simply calls ``save_pretrained`` on the feature extractor and then on the tokenizer. Please
            refer to the docstrings of those methods for more information.

        Args:
            save_directory (:obj:`str` or :obj:`os.PathLike`):
                Directory handed unchanged to both ``save_pretrained`` calls.
        """

        self.feature_extractor.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate a :class:`~transformers.VisionTextDualEncoderProcessor` from a pretrained VisionTextDualEncoder
        processor.

        .. note::

            This class method simply calls :class:`~transformers.AutoFeatureExtractor`'s ``from_pretrained`` and
            :class:`~transformers.AutoTokenizer`'s ``from_pretrained`` with the same arguments. Please refer to the
            docstrings of those methods for more information.

        Args:
            pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
                Identifier or path forwarded unchanged to both loaders, so it must resolve to a location that holds
                both the feature extractor and the tokenizer files.
            **kwargs
                Additional keyword arguments passed along to both loaders.

        Returns:
            :class:`~transformers.VisionTextDualEncoderProcessor`: A processor wrapping the two loaded objects.
        """
        feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)

        return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)

    def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
        """
        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the
        :obj:`text` and :obj:`kwargs` arguments to the tokenizer's ``__call__`` if :obj:`text` is not :obj:`None` to
        encode the text. To prepare the image(s), this method forwards the :obj:`images` and :obj:`kwargs` arguments
        to the feature extractor's ``__call__`` if :obj:`images` is not :obj:`None`. Please refer to the docstring of
        the above two methods for more information.

        Note that the same :obj:`kwargs` are passed to both the tokenizer and the feature extractor.

        Args:
            text (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                :obj:`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            images (:obj:`PIL.Image.Image`, :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`List[PIL.Image.Image]`, :obj:`List[np.ndarray]`, :obj:`List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.

            return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):
                If set, will return tensors of a particular framework. Acceptable values are:

                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
                * :obj:`'np'`: Return NumPy :obj:`np.ndarray` objects.
                * :obj:`'jax'`: Return JAX :obj:`jnp.ndarray` objects.

        Returns:
            :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when :obj:`text` is not :obj:`None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              :obj:`return_attention_mask=True` or if `"attention_mask"` is in :obj:`self.model_input_names` and if
              :obj:`text` is not :obj:`None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when :obj:`images` is not :obj:`None`.

        Raises:
            :obj:`ValueError`: If both :obj:`text` and :obj:`images` are :obj:`None`.
        """

        if text is None and images is None:
            raise ValueError("You have to specify either text or images. Both cannot be none.")

        if text is not None:
            encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)

        if images is not None:
            image_features = self.feature_extractor(images, return_tensors=return_tensors, **kwargs)

        if text is not None and images is not None:
            # Text and images together: attach the pixel values to the tokenizer output.
            encoding["pixel_values"] = image_features.pixel_values
            return encoding
        if text is not None:
            return encoding
        # Images only: re-wrap the feature extractor output so every path returns the same container type.
        return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the wrapped tokenizer's ``batch_decode`` and returns its result.
        Please refer to the docstring of that method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the wrapped tokenizer's ``decode`` and returns its result. Please
        refer to the docstring of that method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)
# Inspired by
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
# From PyTorch internals
def to_2tuple(x):
    """Return ``x`` unchanged when it is iterable, otherwise the pair ``(x, x)``."""
    if isinstance(x, collections.abc.Iterable):
        return x
    return (x, x)


@require_flax
class VisionTextDualEncoderMixin:
    """
    Shared checks for Flax vision-text dual encoder tests.

    Concrete test classes override :meth:`get_vision_text_model`, :meth:`prepare_config_and_inputs` and
    :meth:`get_pretrained_model_and_inputs`; the ``check_*`` / ``test_*`` methods here consume what they return.
    """

    def get_vision_text_model(self, config, text_config):
        # Overridden by subclasses to return `(vision_model, text_model)`.
        pass

    def prepare_config_and_inputs(self):
        # Overridden by subclasses to return the dict that is splatted into the `check_*` methods.
        pass

    def get_pretrained_model_and_inputs(self):
        # Overridden by subclasses to return `(model, inputs)` for the slow save/load test.
        pass

    def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
        """Fail when the largest absolute element-wise difference between ``a`` and ``b`` exceeds ``tol``."""
        diff = np.abs((a - b)).max()
        # The assertion fails only for diff > tol, so the message says ">" rather than ">=".
        self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (> {tol}).")

    def check_model_from_pretrained_configs(
        self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
    ):
        """Build the model from a combined config and check the shapes of the projected embeddings."""
        config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)

        model = FlaxVisionTextDualEncoderModel(config)

        output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)

        self.assertEqual(output["text_embeds"].shape, (input_ids.shape[0], config.projection_dim))
        self.assertEqual(output["image_embeds"].shape, (pixel_values.shape[0], config.projection_dim))

    def check_vision_text_dual_encoder_from_pretrained(
        self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
    ):
        """Build the model from already-instantiated sub-models and check the embedding shapes."""
        vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
        # Named `submodels` so the `**kwargs` parameter is not shadowed.
        submodels = {"vision_model": vision_model, "text_model": text_model}
        model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(**submodels)

        output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)

        self.assertEqual(output["text_embeds"].shape, (input_ids.shape[0], model.config.projection_dim))
        self.assertEqual(output["image_embeds"].shape, (pixel_values.shape[0], model.config.projection_dim))

    def check_save_load(self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs):
        """Check that the first output is unchanged (within 1e-3) after a save/reload round trip."""
        vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
        submodels = {"vision_model": vision_model, "text_model": text_model}
        model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(**submodels)

        output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
        out_1 = output[0]

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            model = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname)

        after_output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
        out_2 = after_output[0]
        max_diff = np.amax(np.abs(out_2 - out_1))
        self.assertLessEqual(max_diff, 1e-3)

    def check_vision_text_output_attention(
        self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
    ):
        """Check the number and trailing shape of the attention maps returned by both sub-models."""
        vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
        submodels = {"vision_model": vision_model, "text_model": text_model}
        model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(**submodels)

        output = model(
            input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, output_attentions=True
        )

        vision_attentions = output.vision_model_output.attentions
        self.assertEqual(len(vision_attentions), vision_config.num_hidden_layers)

        # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
        image_size = to_2tuple(vision_model.config.image_size)
        patch_size = to_2tuple(vision_model.config.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        seq_len = num_patches + 1
        self.assertEqual(vision_attentions[0].shape[-3:], (vision_config.num_attention_heads, seq_len, seq_len))

        text_attentions = output.text_model_output.attentions
        self.assertEqual(len(text_attentions), text_config.num_hidden_layers)

        self.assertEqual(
            text_attentions[0].shape[-3:],
            (text_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]),
        )

    def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
        """
        Compare the first four outputs of ``pt_model`` and ``fx_model`` (tolerance 4e-2), directly and after
        reloading each model from the other framework's saved checkpoint.
        """
        pt_model.to(torch_device)
        pt_model.eval()

        # prepare inputs
        flax_inputs = inputs_dict
        # Create the PyTorch inputs on the device the model was moved to; CPU tensors would fail against a
        # model living on another device.
        pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in flax_inputs.items()}

        with torch.no_grad():
            pt_outputs = pt_model(**pt_inputs).to_tuple()

        fx_outputs = fx_model(**inputs_dict).to_tuple()
        self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
        for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
            # `.cpu()` is a no-op on CPU and is required before `.numpy()` on any other device.
            self.assert_almost_equals(fx_output, pt_output.cpu().numpy(), 4e-2)

        # PT -> Flax
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_model.save_pretrained(tmpdirname)
            fx_model_loaded = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname, from_pt=True)

        fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
        self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
        for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
            self.assert_almost_equals(fx_output_loaded, pt_output.cpu().numpy(), 4e-2)

        # Flax -> PT
        with tempfile.TemporaryDirectory() as tmpdirname:
            fx_model.save_pretrained(tmpdirname)
            pt_model_loaded = VisionTextDualEncoderModel.from_pretrained(tmpdirname, from_flax=True)

        pt_model_loaded.to(torch_device)
        pt_model_loaded.eval()

        with torch.no_grad():
            pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()

        self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
        for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
            self.assert_almost_equals(fx_output, pt_output_loaded.cpu().numpy(), 4e-2)

    def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
        """Copy the PyTorch weights into the Flax model, then run the equivalence check."""
        config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)

        pt_model = VisionTextDualEncoderModel(config)
        fx_model = FlaxVisionTextDualEncoderModel(config)

        fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
        fx_model.params = fx_state

        self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)

    def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict):
        """Copy the Flax weights into the PyTorch model, then run the equivalence check."""
        config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)

        pt_model = VisionTextDualEncoderModel(config)
        fx_model = FlaxVisionTextDualEncoderModel(config)

        pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

        self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)

    def test_model_from_pretrained_configs(self):
        inputs_dict = self.prepare_config_and_inputs()
        self.check_model_from_pretrained_configs(**inputs_dict)

    def test_vision_text_dual_encoder_from_pretrained(self):
        inputs_dict = self.prepare_config_and_inputs()
        self.check_vision_text_dual_encoder_from_pretrained(**inputs_dict)

    def test_save_load(self):
        inputs_dict = self.prepare_config_and_inputs()
        self.check_save_load(**inputs_dict)

    def test_vision_text_output_attention(self):
        inputs_dict = self.prepare_config_and_inputs()
        self.check_vision_text_output_attention(**inputs_dict)

    @is_pt_flax_cross_test
    def test_pt_flax_equivalence(self):
        config_inputs_dict = self.prepare_config_and_inputs()
        vision_config = config_inputs_dict.pop("vision_config")
        text_config = config_inputs_dict.pop("text_config")

        # Whatever is left after removing the two configs is fed to both models as inputs.
        inputs_dict = config_inputs_dict

        self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict)
        self.check_equivalence_flax_to_pt(vision_config, text_config, inputs_dict)

    @slow
    def test_real_model_save_load_from_pretrained(self):
        """Check that a model built from hub checkpoints gives the same first output (within 1e-5) after reload."""
        model_2, inputs = self.get_pretrained_model_and_inputs()

        outputs = model_2(**inputs)
        out_2 = outputs[0]

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model_2.save_pretrained(tmp_dirname)
            model_1 = FlaxVisionTextDualEncoderModel.from_pretrained(tmp_dirname)

        after_outputs = model_1(**inputs)
        out_1 = after_outputs[0]
        max_diff = np.amax(np.abs(out_1 - out_2))
        self.assertLessEqual(max_diff, 1e-5)
def _random_inputs_for(model, batch_size=13, seq_length=4):
    """Build random `pixel_values`, `input_ids` and `attention_mask` sized from ``model.config``."""
    vision_config = model.config.vision_config
    pixel_values = floats_tensor(
        [
            batch_size,
            vision_config.num_channels,
            vision_config.image_size,
            vision_config.image_size,
        ]
    )
    input_ids = ids_tensor([batch_size, seq_length], model.config.text_config.vocab_size)
    attention_mask = random_attention_mask([batch_size, seq_length])
    return {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask}


def _merge_config_and_inputs(vision_tester, text_tester):
    """Combine the outputs of a vision and a text model tester into the dict the mixin checks expect."""
    vision_config, pixel_values = vision_tester.prepare_config_and_inputs()
    text_config, input_ids, token_type_ids, attention_mask = text_tester.prepare_config_and_inputs()

    # Each key appears once; the original literal listed "text_config" twice with the same value.
    return {
        "text_config": text_config,
        "vision_config": vision_config,
        "pixel_values": pixel_values,
        "attention_mask": attention_mask,
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
    }


@require_flax
class FlaxViTBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
    """Runs the shared dual-encoder checks with a ViT vision model and a BERT text model."""

    def get_pretrained_model_and_inputs(self):
        model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
            "hf-internal-testing/tiny-random-vit",
            "hf-internal-testing/tiny-bert",
            vision_from_pt=True,
            text_from_pt=True,
        )
        return model, _random_inputs_for(model)

    def get_vision_text_model(self, vision_config, text_config):
        vision_model = FlaxViTModel(vision_config)
        text_model = FlaxBertModel(text_config)
        return vision_model, text_model

    def prepare_config_and_inputs(self):
        vit_model_tester = FlaxViTModelTester(self)
        bert_model_tester = FlaxBertModelTester(self)
        return _merge_config_and_inputs(vit_model_tester, bert_model_tester)


# `require_torch` is kept from the original; `require_flax` is added so this class is gated the same way
# as the other Flax test classes here instead of relying only on the decorator of the mixin.
@require_flax
@require_torch
class FlaxCLIPVisionBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
    """Runs the shared dual-encoder checks with a CLIP vision model and a BERT text model."""

    def get_pretrained_model_and_inputs(self):
        model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
            "hf-internal-testing/tiny-random-clip",
            "hf-internal-testing/tiny-bert",
            vision_from_pt=True,
            text_from_pt=True,
        )
        return model, _random_inputs_for(model)

    def get_vision_text_model(self, vision_config, text_config):
        vision_model = FlaxCLIPVisionModel(vision_config)
        text_model = FlaxBertModel(text_config)
        return vision_model, text_model

    def prepare_config_and_inputs(self):
        clip_model_tester = FlaxCLIPVisionModelTester(self)
        bert_model_tester = FlaxBertModelTester(self)
        return _merge_config_and_inputs(clip_model_tester, bert_model_tester)


@require_flax
@require_vision
class FlaxVisionTextDualEncoderIntegrationTest(unittest.TestCase):
    """Slow end-to-end check of a hub checkpoint against fixed reference logits."""

    @slow
    def test_inference(self):
        model = FlaxVisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian", logit_scale_init_value=1)
        processor = VisionTextDualEncoderProcessor.from_pretrained("clip-italian/clip-italian")

        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
        inputs = processor(
            text=["una foto di un gatto", "una foto di un cane"], images=image, padding=True, return_tensors="np"
        )

        outputs = model(**inputs)

        # verify the logits
        self.assertEqual(outputs.logits_per_image.shape, (inputs.pixel_values.shape[0], inputs.input_ids.shape[0]))
        self.assertEqual(
            outputs.logits_per_text.shape,
            (inputs.input_ids.shape[0], inputs.pixel_values.shape[0]),
        )

        expected_logits = np.array([[1.2284727, 0.3104122]])

        self.assertTrue(np.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
""" + + +import collections +import tempfile +import unittest + +import numpy as np + +from transformers.file_utils import is_flax_available, is_torch_available, is_vision_available +from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_vision, slow, torch_device + +from .test_modeling_bert import BertModelTester +from .test_modeling_clip import CLIPVisionModelTester +from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask +from .test_modeling_deit import DeiTModelTester +from .test_modeling_roberta import RobertaModelTester +from .test_modeling_vit import ViTModelTester + + +if is_torch_available(): + import torch + + from transformers import ( + BertModel, + CLIPVisionModel, + DeiTModel, + RobertaModel, + VisionTextDualEncoderConfig, + VisionTextDualEncoderModel, + ViTModel, + ) + +if is_flax_available(): + from transformers import FlaxVisionTextDualEncoderModel + from transformers.modeling_flax_pytorch_utils import ( + convert_pytorch_state_dict_to_flax, + load_flax_weights_in_pytorch_model, + ) + +if is_vision_available(): + from PIL import Image + + from transformers import VisionTextDualEncoderProcessor + + +# Inspired by +# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py +# From PyTorch internals +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +@require_torch +class VisionTextDualEncoderMixin: + def get_vision_text_model(self, config, text_config): + pass + + def prepare_config_and_inputs(self): + pass + + def get_pretrained_model_and_inputs(self): + pass + + def check_model_from_pretrained_configs( + self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs + ): + config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config) + + model = VisionTextDualEncoderModel(config) + model.to(torch_device) + model.eval() + + 
output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask) + + self.assertEqual(output["text_embeds"].shape, (input_ids.shape[0], config.projection_dim)) + self.assertEqual(output["image_embeds"].shape, (pixel_values.shape[0], config.projection_dim)) + + def check_vision_text_dual_encoder_model( + self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs + ): + vision_model, text_model = self.get_vision_text_model(vision_config, text_config) + model = VisionTextDualEncoderModel(vision_model=vision_model, text_model=text_model) + model.to(torch_device) + model.eval() + + output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask) + + self.assertEqual(output["text_embeds"].shape, (input_ids.shape[0], model.config.projection_dim)) + self.assertEqual(output["image_embeds"].shape, (pixel_values.shape[0], model.config.projection_dim)) + + def check_vision_text_dual_encoder_from_pretrained( + self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs + ): + + vision_model, text_model = self.get_vision_text_model(vision_config, text_config) + kwargs = {"vision_model": vision_model, "text_model": text_model} + model = VisionTextDualEncoderModel.from_vision_text_pretrained(**kwargs) + model.to(torch_device) + model.eval() + + output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask) + + self.assertEqual(output["text_embeds"].shape, (input_ids.shape[0], model.config.projection_dim)) + self.assertEqual(output["image_embeds"].shape, (pixel_values.shape[0], model.config.projection_dim)) + + def check_save_load(self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs): + vision_model, text_model = self.get_vision_text_model(vision_config, text_config) + model = VisionTextDualEncoderModel(vision_model=vision_model, text_model=text_model) + model.to(torch_device) + model.eval() + + with 
torch.no_grad(): + output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask) + out_1 = output[0].cpu().numpy() + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = VisionTextDualEncoderModel.from_pretrained(tmpdirname).eval() + model.to(torch_device) + + after_output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask) + out_2 = after_output[0].cpu().numpy() + max_diff = np.amax(np.abs(out_2 - out_1)) + self.assertLessEqual(max_diff, 1e-5) + + def check_vision_text_output_attention( + self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs + ): + vision_model, text_model = self.get_vision_text_model(vision_config, text_config) + model = VisionTextDualEncoderModel(vision_model=vision_model, text_model=text_model) + model.to(torch_device) + model.eval() + + output = model( + input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, output_attentions=True + ) + + vision_attentions = output.vision_model_output.attentions + self.assertEqual(len(vision_attentions), vision_config.num_hidden_layers) + + # in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token) + image_size = to_2tuple(vision_model.config.image_size) + patch_size = to_2tuple(vision_model.config.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_len = num_patches + 1 + self.assertEqual(vision_attentions[0].shape[-3:], (vision_config.num_attention_heads, seq_len, seq_len)) + + text_attentions = output.text_model_output.attentions + self.assertEqual(len(text_attentions), text_config.num_hidden_layers) + + self.assertEqual( + text_attentions[0].shape[-3:], + (text_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]), + ) + + def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float): + diff = np.abs((a - b)).max() + 
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") + + def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mask, pixel_values, **kwargs): + + pt_model.to(torch_device) + pt_model.eval() + + # prepare inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values} + pt_inputs = inputs_dict + flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()} + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs).to_tuple() + + fx_outputs = fx_model(**flax_inputs).to_tuple() + self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") + for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): + self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + + # PT -> Flax + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + fx_model_loaded = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname, from_pt=True) + + fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple() + self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch") + for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]): + self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2) + + # Flax -> PT + with tempfile.TemporaryDirectory() as tmpdirname: + fx_model.save_pretrained(tmpdirname) + pt_model_loaded = VisionTextDualEncoderModel.from_pretrained(tmpdirname, from_flax=True) + + pt_model_loaded.to(torch_device) + pt_model_loaded.eval() + + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() + + self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch") + for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]): + self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2) + + def 
check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict): + + config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config) + + pt_model = VisionTextDualEncoderModel(config) + fx_model = FlaxVisionTextDualEncoderModel(config) + + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) + fx_model.params = fx_state + + self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict) + + def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict): + + config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config) + + pt_model = VisionTextDualEncoderModel(config) + fx_model = FlaxVisionTextDualEncoderModel(config) + + pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) + + self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict) + + def test_vision_text_dual_encoder_model(self): + inputs_dict = self.prepare_config_and_inputs() + self.check_vision_text_dual_encoder_model(**inputs_dict) + + def test_model_from_pretrained_configs(self): + inputs_dict = self.prepare_config_and_inputs() + self.check_model_from_pretrained_configs(**inputs_dict) + + def test_vision_text_dual_encoder_from_pretrained(self): + inputs_dict = self.prepare_config_and_inputs() + self.check_vision_text_dual_encoder_from_pretrained(**inputs_dict) + + def test_save_load(self): + inputs_dict = self.prepare_config_and_inputs() + self.check_save_load(**inputs_dict) + + def test_vision_text_output_attention(self): + inputs_dict = self.prepare_config_and_inputs() + self.check_vision_text_output_attention(**inputs_dict) + + @is_pt_flax_cross_test + def test_pt_flax_equivalence(self): + + config_inputs_dict = self.prepare_config_and_inputs() + vision_config = config_inputs_dict.pop("vision_config") + text_config = config_inputs_dict.pop("text_config") + + inputs_dict = config_inputs_dict + + self.check_equivalence_pt_to_flax(vision_config, text_config, 
inputs_dict) + self.check_equivalence_flax_to_pt(vision_config, text_config, inputs_dict) + + @slow + def test_real_model_save_load_from_pretrained(self): + model_2, inputs = self.get_pretrained_model_and_inputs() + model_2.to(torch_device) + + with torch.no_grad(): + outputs = model_2(**inputs) + out_2 = outputs[0].cpu().numpy() + + with tempfile.TemporaryDirectory() as tmp_dirname: + model_2.save_pretrained(tmp_dirname) + model_1 = VisionTextDualEncoderModel.from_pretrained(tmp_dirname) + model_1.to(torch_device) + + after_outputs = model_1(**inputs) + out_1 = after_outputs[0].cpu().numpy() + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) + + +@require_torch +class ViTBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase): + def get_pretrained_model_and_inputs(self): + model = VisionTextDualEncoderModel.from_vision_text_pretrained( + "hf-internal-testing/tiny-random-vit", "hf-internal-testing/tiny-bert" + ) + batch_size = 13 + pixel_values = floats_tensor( + [ + batch_size, + model.vision_model.config.num_channels, + model.vision_model.config.image_size, + model.vision_model.config.image_size, + ] + ) + input_ids = ids_tensor([batch_size, 4], model.text_model.config.vocab_size) + attention_mask = random_attention_mask([batch_size, 4]) + inputs = {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask} + + return model, inputs + + def get_vision_text_model(self, vision_config, text_config): + vision_model = ViTModel(vision_config).eval() + text_model = BertModel(text_config).eval() + return vision_model, text_model + + def prepare_config_and_inputs(self): + vit_model_tester = ViTModelTester(self) + bert_model_tester = BertModelTester(self) + vision_config_and_inputs = vit_model_tester.prepare_config_and_inputs() + text_config_and_inputs = bert_model_tester.prepare_config_and_inputs() + + vision_config, pixel_values, _ = vision_config_and_inputs + + ( + text_config, + input_ids, + 
token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = text_config_and_inputs + + return { + "text_config": text_config, + "vision_config": vision_config, + "pixel_values": pixel_values, + "attention_mask": input_mask, + "text_config": text_config, + "input_ids": input_ids, + "text_token_type_ids": token_type_ids, + "text_sequence_labels": sequence_labels, + "text_token_labels": token_labels, + "text_choice_labels": choice_labels, + } + + +@require_torch +class DeiTRobertaModelTest(VisionTextDualEncoderMixin, unittest.TestCase): + def get_pretrained_model_and_inputs(self): + model = VisionTextDualEncoderModel.from_vision_text_pretrained( + "hf-internal-testing/tiny-random-deit", "hf-internal-testing/tiny-random-roberta" + ) + batch_size = 13 + pixel_values = floats_tensor( + [ + batch_size, + model.vision_model.config.num_channels, + model.vision_model.config.image_size, + model.vision_model.config.image_size, + ] + ) + input_ids = ids_tensor([batch_size, 4], model.text_model.config.vocab_size) + attention_mask = random_attention_mask([batch_size, 4]) + inputs = {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask} + + return model, inputs + + def check_vision_text_output_attention( + self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs + ): + vision_model, text_model = self.get_vision_text_model(vision_config, text_config) + model = VisionTextDualEncoderModel(vision_model=vision_model, text_model=text_model) + model.to(torch_device) + model.eval() + + output = model( + input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, output_attentions=True + ) + + vision_attentions = output.vision_model_output.attentions + self.assertEqual(len(vision_attentions), vision_config.num_hidden_layers) + + # in DEiT, the seq_len equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens) + image_size = 
to_2tuple(vision_model.config.image_size) + patch_size = to_2tuple(vision_model.config.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_len = num_patches + 2 + self.assertEqual(vision_attentions[0].shape[-3:], (vision_config.num_attention_heads, seq_len, seq_len)) + + text_attentions = output.text_model_output.attentions + self.assertEqual(len(text_attentions), text_config.num_hidden_layers) + + self.assertEqual( + text_attentions[0].shape[-3:], + (text_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]), + ) + + def get_vision_text_model(self, vision_config, text_config): + vision_model = DeiTModel(vision_config).eval() + text_model = RobertaModel(text_config).eval() + return vision_model, text_model + + def prepare_config_and_inputs(self): + vit_model_tester = DeiTModelTester(self) + bert_model_tester = RobertaModelTester(self) + vision_config_and_inputs = vit_model_tester.prepare_config_and_inputs() + text_config_and_inputs = bert_model_tester.prepare_config_and_inputs() + + vision_config, pixel_values, _ = vision_config_and_inputs + + ( + text_config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = text_config_and_inputs + + return { + "text_config": text_config, + "vision_config": vision_config, + "pixel_values": pixel_values, + "attention_mask": input_mask, + "text_config": text_config, + "input_ids": input_ids, + "text_token_type_ids": token_type_ids, + "text_sequence_labels": sequence_labels, + "text_token_labels": token_labels, + "text_choice_labels": choice_labels, + } + + # skip as DeiT is not available in Flax + def test_pt_flax_equivalence(self): + pass + + +@require_torch +class CLIPVisionBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase): + def get_pretrained_model_and_inputs(self): + model = VisionTextDualEncoderModel.from_vision_text_pretrained( + "hf-internal-testing/tiny-random-clip", 
"hf-internal-testing/tiny-bert" + ) + batch_size = 13 + pixel_values = floats_tensor( + [ + batch_size, + model.vision_model.config.num_channels, + model.vision_model.config.image_size, + model.vision_model.config.image_size, + ] + ) + input_ids = ids_tensor([batch_size, 4], model.text_model.config.vocab_size) + attention_mask = random_attention_mask([batch_size, 4]) + inputs = {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask} + + return model, inputs + + def get_vision_text_model(self, vision_config, text_config): + vision_model = CLIPVisionModel(vision_config).eval() + text_model = BertModel(text_config).eval() + return vision_model, text_model + + def prepare_config_and_inputs(self): + clip_model_tester = CLIPVisionModelTester(self) + bert_model_tester = BertModelTester(self) + vision_config_and_inputs = clip_model_tester.prepare_config_and_inputs() + text_config_and_inputs = bert_model_tester.prepare_config_and_inputs() + + vision_config, pixel_values = vision_config_and_inputs + + ( + text_config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = text_config_and_inputs + + return { + "text_config": text_config, + "vision_config": vision_config, + "pixel_values": pixel_values, + "attention_mask": input_mask, + "text_config": text_config, + "input_ids": input_ids, + "text_token_type_ids": token_type_ids, + "text_sequence_labels": sequence_labels, + "text_token_labels": token_labels, + "text_choice_labels": choice_labels, + } + + +@require_vision +@require_torch +class VisionTextDualEncoderIntegrationTest(unittest.TestCase): + @slow + def test_inference(self): + model = VisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian", logit_scale_init_value=1) + processor = VisionTextDualEncoderProcessor.from_pretrained("clip-italian/clip-italian") + + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + inputs = processor( + text=["una foto 
di un gatto", "una foto di un cane"], images=image, padding=True, return_tensors="pt" + ) + + outputs = model(**inputs) + + # verify the logits + self.assertEqual(outputs.logits_per_image.shape, (inputs.pixel_values.shape[0], inputs.input_ids.shape[0])) + self.assertEqual( + outputs.logits_per_text.shape, + (inputs.input_ids.shape[0], inputs.pixel_values.shape[0]), + ) + + expected_logits = torch.tensor([[1.2284727, 0.3104122]]) + + self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3)) diff --git a/tests/test_processor_vision_text_dual_encoder.py b/tests/test_processor_vision_text_dual_encoder.py new file mode 100644 index 0000000000..ed23e1659a --- /dev/null +++ b/tests/test_processor_vision_text_dual_encoder.py @@ -0,0 +1,170 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os +import shutil +import tempfile +import unittest + +import numpy as np + +from transformers import BertTokenizerFast +from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_vision_available +from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES, BertTokenizer +from transformers.testing_utils import require_tokenizers, require_vision + + +if is_vision_available(): + from PIL import Image + + from transformers import VisionTextDualEncoderProcessor, ViTFeatureExtractor + + +@require_tokenizers +@require_vision +class VisionTextDualEncoderProcessorTest(unittest.TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + + # fmt: off + vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "want", "##want", "##ed", "wa", "un", "runn", "##ing", ",", "low", "lowest"] + # fmt: on + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + + feature_extractor_map = { + "do_resize": True, + "size": 18, + "do_normalize": True, + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5], + } + self.feature_extractor_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME) + with open(self.feature_extractor_file, "w", encoding="utf-8") as fp: + json.dump(feature_extractor_map, fp) + + def get_tokenizer(self, **kwargs): + return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs) + + def get_feature_extractor(self, **kwargs): + return ViTFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def prepare_image_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. 
+ """ + + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + return image_inputs + + def test_save_load_pretrained_default(self): + tokenizer = self.get_tokenizer() + feature_extractor = self.get_feature_extractor() + + processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + processor.save_pretrained(self.tmpdirname) + processor = VisionTextDualEncoderProcessor.from_pretrained(self.tmpdirname) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab()) + self.assertIsInstance(processor.tokenizer, (BertTokenizer, BertTokenizerFast)) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string()) + self.assertIsInstance(processor.feature_extractor, ViTFeatureExtractor) + + def test_save_load_pretrained_additional_features(self): + processor = VisionTextDualEncoderProcessor( + tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor() + ) + processor.save_pretrained(self.tmpdirname) + + tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)") + feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0) + + processor = VisionTextDualEncoderProcessor.from_pretrained( + self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0 + ) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) + self.assertIsInstance(processor.tokenizer, (BertTokenizer, BertTokenizerFast)) + + self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.feature_extractor, ViTFeatureExtractor) + + def test_feature_extractor(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = 
VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + image_input = self.prepare_image_inputs() + + input_feat_extract = feature_extractor(image_input, return_tensors="np") + input_processor = processor(images=image_input, return_tensors="np") + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + def test_tokenizer(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "lower newer" + + encoded_processor = processor(text=input_str) + + encoded_tok = tokenizer(input_str) + + for key in encoded_tok.keys(): + self.assertListEqual(encoded_tok[key], encoded_processor[key]) + + def test_processor(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input) + + self.assertListEqual(list(inputs.keys()), ["input_ids", "token_type_ids", "attention_mask", "pixel_values"]) + + # test if it raises when no input is passed + with self.assertRaises(ValueError): + processor() + + def test_tokenizer_decode(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]] + + decoded_processor = processor.batch_decode(predicted_ids) + decoded_tok = tokenizer.batch_decode(predicted_ids) + + self.assertListEqual(decoded_tok, decoded_processor) diff --git a/utils/check_repo.py b/utils/check_repo.py index 9082ca4b88..997af5c618 100644 --- 
a/utils/check_repo.py +++ b/utils/check_repo.py @@ -94,6 +94,8 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [ "test_modeling_tf_xlm_roberta.py", "test_modeling_xlm_prophetnet.py", "test_modeling_xlm_roberta.py", + "test_modeling_vision_text_dual_encoder.py", + "test_modeling_flax_vision_text_dual_encoder.py", ] # Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and