VisionTextDualEncoder (#13511)
* init vision_text_dual_encoder
* fix merge
* remove extra heads
* fix tests
* remove VISION_TEXT_DUAL_ENCODER_PRETRAINED_CONFIG_ARCHIVE_MAP
* remove archive map
* fix imports
* fix more imports
* fix init
* delete tokenizers
* fix imports
* clean
* support clip's vision model
* handle None config
* begin tests
* more test and few fixes
* warn about newly init weights
* more tests
* add loss to model
* remove extra classes from doc
* add processor
* doc and small fixes
* add start docstr
* update flax model
* flax tests
* more flax tests
* doc
* quality
* doc and quality
* fix doc
* doc
* remove comments
* update warning
* quality
* fix docs
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* replace asserts, fix imports
* update imports
* fix import
* address some review comments
* fix check
* reduce tolerance
* fix test
* add flax integration test
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* address Sylvain's comments
* fix style
* add pt_flax_equivalence test in PT tests
* add pt integration test
* update test
* use pre-trained checkpoint in examples

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@@ -511,6 +511,8 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|    Vision Encoder decoder   |       ❌       |       ❌       |       ✅        |         ❌         |      ✅      |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|    VisionTextDualEncoder    |       ❌       |       ❌       |       ✅        |         ❌         |      ✅      |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|         VisualBert          |       ❌       |       ❌       |       ✅        |         ❌         |      ❌      |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|             ViT             |       ❌       |       ❌       |       ✅        |         ✅         |      ✅      |
@@ -686,6 +688,7 @@ Flax), PyTorch, and/or TensorFlow.
    model_doc/unispeech
    model_doc/unispeech_sat
    model_doc/visionencoderdecoder
    model_doc/vision_text_dual_encoder
    model_doc/vit
    model_doc/visual_bert
    model_doc/wav2vec2
56
docs/source/model_doc/vision_text_dual_encoder.rst
Normal file
@@ -0,0 +1,56 @@
..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

VisionTextDualEncoder
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The :class:`~transformers.VisionTextDualEncoderModel` can be used to initialize a vision-text dual encoder model with
any pretrained vision autoencoding model as the vision encoder (*e.g.* :doc:`ViT <vit>`, :doc:`BEiT <beit>`, :doc:`DeiT
<deit>`) and any pretrained text autoencoding model as the text encoder (*e.g.* :doc:`RoBERTa <roberta>`, :doc:`BERT
<bert>`). Two projection layers are added on top of the vision and text encoders to project the output embeddings
to a shared latent space. The projection layers are randomly initialized, so the model should be fine-tuned on a
downstream task. This model can be used to align the vision-text embeddings using CLIP-like contrastive image-text
training and can then be used for zero-shot vision tasks such as image classification or retrieval.

In `LiT: Zero-Shot Transfer with Locked-image Text Tuning <https://arxiv.org/abs/2111.07991>`__ it is shown how
leveraging pre-trained (locked/frozen) image and text models for contrastive learning yields significant
improvement on new zero-shot vision tasks such as image classification or retrieval.

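The "CLIP-like contrastive image-text training" mentioned in the overview can be illustrated with a minimal, dependency-free sketch. It assumes a square similarity matrix in which entry ``(i, j)`` scores text ``i`` against image ``j`` and the matching pairs sit on the diagonal. The function names are hypothetical and this is not the loss code added by the commit, which is not shown in this excerpt.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of one list of logits."""
    m = max(row)
    log_sum = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum for x in row]


def contrastive_loss(logits):
    """Cross-entropy where the correct class of row i is column i."""
    n = len(logits)
    return -sum(log_softmax(logits[i])[i] for i in range(n)) / n


def clip_style_loss(logits_per_text):
    """Average of the text-to-image and image-to-text losses."""
    # Transposing the matrix swaps the roles of texts and images.
    logits_per_image = [list(col) for col in zip(*logits_per_text)]
    return (contrastive_loss(logits_per_text) + contrastive_loss(logits_per_image)) / 2.0


# Matching pairs score high on the diagonal: low loss.
aligned = [[10.0, 0.0], [0.0, 10.0]]
# Each text scores highest against the wrong image: high loss.
misaligned = [[0.0, 10.0], [10.0, 0.0]]

print(clip_style_loss(aligned), clip_style_loss(misaligned))
```

Averaging both directions keeps the objective symmetric, so neither encoder is treated as the fixed "query" side.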
VisionTextDualEncoderConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.VisionTextDualEncoderConfig
    :members:


VisionTextDualEncoderProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.VisionTextDualEncoderProcessor
    :members:


VisionTextDualEncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.VisionTextDualEncoderModel
    :members: forward


FlaxVisionTextDualEncoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxVisionTextDualEncoderModel
    :members: __call__
@@ -297,6 +297,7 @@ _import_structure = {
        "UniSpeechSatConfig",
    ],
    "models.vision_encoder_decoder": ["VisionEncoderDecoderConfig"],
    "models.vision_text_dual_encoder": ["VisionTextDualEncoderConfig", "VisionTextDualEncoderProcessor"],
    "models.visual_bert": ["VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "VisualBertConfig"],
    "models.vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
    "models.wav2vec2": [
@@ -1307,6 +1308,7 @@ if is_torch_available():
        ]
    )
    _import_structure["models.vision_encoder_decoder"].extend(["VisionEncoderDecoderModel"])
    _import_structure["models.vision_text_dual_encoder"].extend(["VisionTextDualEncoderModel"])
    _import_structure["models.visual_bert"].extend(
        [
            "VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1908,6 +1910,7 @@ if is_flax_available():
    )

    # Flax models structure

    _import_structure["models.bart"].extend(
        [
            "FlaxBartForConditionalGeneration",
@@ -2028,6 +2031,7 @@ if is_flax_available():
    )
    _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
    _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel")
    _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"])
    _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
    _import_structure["models.wav2vec2"].extend(
        ["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]
@@ -2268,6 +2272,7 @@ if TYPE_CHECKING:
    from .models.unispeech import UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechConfig
    from .models.unispeech_sat import UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechSatConfig
    from .models.vision_encoder_decoder import VisionEncoderDecoderConfig
    from .models.vision_text_dual_encoder import VisionTextDualEncoderConfig, VisionTextDualEncoderProcessor
    from .models.visual_bert import VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VisualBertConfig
    from .models.vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
    from .models.wav2vec2 import (
@@ -3111,6 +3116,7 @@ if TYPE_CHECKING:
            UniSpeechSatPreTrainedModel,
        )
        from .models.vision_encoder_decoder import VisionEncoderDecoderModel
        from .models.vision_text_dual_encoder import VisionTextDualEncoderModel
        from .models.visual_bert import (
            VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            VisualBertForMultipleChoice,
@@ -3706,6 +3712,7 @@ if TYPE_CHECKING:
        )
        from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
        from .models.vision_encoder_decoder import FlaxVisionEncoderDecoderModel
        from .models.vision_text_dual_encoder import FlaxVisionTextDualEncoderModel
        from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
        from .models.wav2vec2 import (
            FlaxWav2Vec2ForCTC,
@@ -101,6 +101,7 @@ from . import (
    unispeech,
    unispeech_sat,
    vision_encoder_decoder,
    vision_text_dual_encoder,
    visual_bert,
    vit,
    wav2vec2,
@@ -36,6 +36,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("trocr", "TrOCRConfig"),
        ("fnet", "FNetConfig"),
        ("segformer", "SegformerConfig"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderConfig"),
        ("gptj", "GPTJConfig"),
        ("layoutlmv2", "LayoutLMv2Config"),
        ("beit", "BeitConfig"),
@@ -192,6 +193,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("trocr", "TrOCR"),
        ("fnet", "FNet"),
        ("segformer", "SegFormer"),
        ("vision-text-dual-encoder", "VisionTextDualEncoder"),
        ("gptj", "GPT-J"),
        ("beit", "BEiT"),
        ("rembert", "RemBERT"),
@@ -32,6 +32,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("qdqbert", "QDQBertModel"),
        ("fnet", "FNetModel"),
        ("segformer", "SegformerModel"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
        ("gptj", "GPTJModel"),
        ("layoutlmv2", "LayoutLMv2Model"),
        ("beit", "BeitModel"),
@@ -29,6 +29,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("pegasus", "FlaxPegasusModel"),
        ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"),
        ("distilbert", "FlaxDistilBertModel"),
        ("albert", "FlaxAlbertModel"),
        ("roberta", "FlaxRobertaModel"),
@@ -41,6 +41,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
        ("speech_to_text_2", "Speech2Text2Processor"),
        ("trocr", "TrOCRProcessor"),
        ("wav2vec2", "Wav2Vec2Processor"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
    ]
)

52
src/transformers/models/vision_text_dual_encoder/__init__.py
Normal file
@@ -0,0 +1,52 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

# rely on isort to merge the imports
from ...file_utils import _LazyModule, is_flax_available, is_torch_available


_import_structure = {
    "configuration_vision_text_dual_encoder": ["VisionTextDualEncoderConfig"],
    "processing_vision_text_dual_encoder": ["VisionTextDualEncoderProcessor"],
}


if is_torch_available():
    _import_structure["modeling_vision_text_dual_encoder"] = ["VisionTextDualEncoderModel"]


if is_flax_available():
    _import_structure["modeling_flax_vision_text_dual_encoder"] = ["FlaxVisionTextDualEncoderModel"]


if TYPE_CHECKING:
    from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig
    from .processing_vision_text_dual_encoder import VisionTextDualEncoderProcessor

    if is_torch_available():
        from .modeling_vision_text_dual_encoder import VisionTextDualEncoderModel

    if is_flax_available():
        from .modeling_flax_vision_text_dual_encoder import FlaxVisionTextDualEncoderModel


else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
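The `_import_structure` dictionary above maps submodule names to the names they export, so the package can defer the heavy imports until a name is first used. A minimal sketch of that idea follows, using only the standard library. `TinyLazyModule` is a hypothetical stand-in written for illustration, not the library's `_LazyModule`, and the two standard-library modules in the demo are arbitrary placeholders for the real submodules.

```python
import importlib
import types


class TinyLazyModule(types.ModuleType):
    """Hypothetical stand-in: resolves exported names on first attribute access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Invert {module: [names]} into {name: module} for quick lookup.
        self._name_to_module = {
            attr: module_name for module_name, attrs in import_structure.items() for attr in attrs
        }
        self.loaded = []  # records which backing modules were actually imported

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails.
        if attr.startswith("_") or attr not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        module_name = self._name_to_module[attr]
        value = getattr(importlib.import_module(module_name), attr)
        self.loaded.append(module_name)
        setattr(self, attr, value)  # cache so later lookups skip __getattr__
        return value


lazy = TinyLazyModule("demo", {"json": ["dumps"], "fractions": ["Fraction"]})
print(lazy.loaded)         # nothing resolved yet
print(lazy.dumps([1, 2]))  # first access triggers the import
print(lazy.loaded)
```

Keeping the typed imports under `TYPE_CHECKING` while routing runtime access through a lazy module gives static tools full visibility without paying the import cost up front.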
@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" VisionTextDualEncoder model configuration """

import copy

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..clip.configuration_clip import CLIPVisionConfig


logger = logging.get_logger(__name__)


class VisionTextDualEncoderConfig(PretrainedConfig):
    r"""
    :class:`~transformers.VisionTextDualEncoderConfig` is the configuration class to store the configuration of a
    :class:`~transformers.VisionTextDualEncoderModel`. It is used to instantiate a
    :class:`~transformers.VisionTextDualEncoderModel` according to the specified arguments, defining the text
    model and vision model configs.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Args:
        text_config (:obj:`dict`):
            Dictionary of configuration options that defines the text model config.
        vision_config (:obj:`dict`):
            Dictionary of configuration options that defines the vision model config.
        projection_dim (:obj:`int`, `optional`, defaults to 512):
            Dimensionality of the text and vision projection layers.
        logit_scale_init_value (:obj:`float`, `optional`, defaults to 2.6592):
            The initial value of the `logit_scale` parameter. The default is taken from the original CLIP
            implementation.
        kwargs (`optional`):
            Dictionary of keyword arguments.

    Examples::

        >>> from transformers import ViTConfig, BertConfig, VisionTextDualEncoderConfig, VisionTextDualEncoderModel

        >>> # Initializing a BERT and ViT configuration
        >>> config_vision = ViTConfig()
        >>> config_text = BertConfig()

        >>> config = VisionTextDualEncoderConfig.from_vision_text_configs(config_vision, config_text, projection_dim=512)

        >>> # Initializing a BERT and ViT model
        >>> model = VisionTextDualEncoderModel(config=config)

        >>> # Accessing the model configuration
        >>> config_vision = model.config.vision_config
        >>> config_text = model.config.text_config

        >>> # Saving the model, including its configuration
        >>> model.save_pretrained('vit-bert')

        >>> # Loading the model and config from the saved folder
        >>> vision_text_config = VisionTextDualEncoderConfig.from_pretrained('vit-bert')
        >>> model = VisionTextDualEncoderModel.from_pretrained('vit-bert', config=vision_text_config)
    """

    model_type = "vision-text-dual-encoder"
    is_composition = True

    def __init__(self, projection_dim=512, logit_scale_init_value=2.6592, **kwargs):
        super().__init__(**kwargs)

        if "vision_config" not in kwargs:
            raise ValueError("`vision_config` can not be `None`.")

        if "text_config" not in kwargs:
            raise ValueError("`text_config` can not be `None`.")

        vision_config = kwargs.pop("vision_config")
        text_config = kwargs.pop("text_config")

        vision_model_type = vision_config.pop("model_type")
        text_model_type = text_config.pop("model_type")

        if vision_model_type == "clip":
            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
        elif vision_model_type == "clip_vision_model":
            self.vision_config = CLIPVisionConfig(**vision_config)
        else:
            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)

        self.text_config = AutoConfig.for_model(text_model_type, **text_config)

        self.projection_dim = projection_dim
        self.logit_scale_init_value = logit_scale_init_value

    @classmethod
    def from_vision_text_configs(cls, vision_config: PretrainedConfig, text_config: PretrainedConfig, **kwargs):
        r"""
        Instantiate a :class:`VisionTextDualEncoderConfig` (or a derived class) from a vision model configuration
        and a text model configuration.

        Returns:
            :class:`VisionTextDualEncoderConfig`: An instance of a configuration object
        """

        return cls(vision_config=vision_config.to_dict(), text_config=text_config.to_dict(), **kwargs)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default
        :meth:`~transformers.PretrainedConfig.to_dict`.

        Returns:
            :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
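The configuration above is a composite: each sub-config is serialized to a dictionary that carries its own `model_type`, and on construction that key is popped and used to pick the right config class. The following sketch reproduces the round trip with toy classes. Every name in it (`ToyConfig`, `ToyDualConfig`, `REGISTRY`) is hypothetical, and the registry dictionary only plays the role that `AutoConfig.for_model` plays in the real code. Unlike the code above, the sketch copies the incoming dictionaries before popping, so the caller's dictionaries are left untouched.

```python
import copy


class ToyConfig:
    """Hypothetical sub-config standing in for a vision or text config."""

    model_type = "toy"

    def __init__(self, hidden_size=8):
        self.hidden_size = hidden_size

    def to_dict(self):
        output = copy.deepcopy(self.__dict__)
        output["model_type"] = self.model_type
        return output


class ToyVisionConfig(ToyConfig):
    model_type = "toy-vision"


class ToyTextConfig(ToyConfig):
    model_type = "toy-text"


# Hypothetical registry: model_type string -> config class.
REGISTRY = {cls.model_type: cls for cls in (ToyVisionConfig, ToyTextConfig)}


class ToyDualConfig:
    model_type = "toy-dual"

    def __init__(self, vision_config, text_config, projection_dim=512):
        # Work on copies so popping "model_type" does not mutate the inputs.
        vision_config = dict(vision_config)
        text_config = dict(text_config)
        self.vision_config = REGISTRY[vision_config.pop("model_type")](**vision_config)
        self.text_config = REGISTRY[text_config.pop("model_type")](**text_config)
        self.projection_dim = projection_dim

    def to_dict(self):
        output = copy.deepcopy(self.__dict__)
        # Replace the nested config objects with plain dictionaries.
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.model_type
        return output


config = ToyDualConfig(ToyVisionConfig(hidden_size=16).to_dict(), ToyTextConfig(hidden_size=32).to_dict())
serialized = config.to_dict()
restored = ToyDualConfig(**{k: v for k, v in serialized.items() if k != "model_type"})
print(serialized)
```

Storing `model_type` inside each nested dictionary is what lets a saved composite config be rebuilt without knowing in advance which encoders it wraps.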
@@ -0,0 +1,568 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax VisionTextDualEncoder model. """


from typing import Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from ...file_utils import add_start_docstrings
from ...modeling_flax_utils import FlaxPreTrainedModel, append_replace_return_docstrings, overwrite_call_docstring
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_flax_auto import FLAX_MODEL_MAPPING, FlaxAutoModel
from ..clip.modeling_flax_clip import FlaxCLIPOutput, FlaxCLIPVisionModel
from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "VisionTextDualEncoderConfig"

VISION_TEXT_DUAL_ENCODER_START_DOCSTRING = r"""
    This class can be used to initialize a vision-text dual encoder model with any pretrained vision autoencoding model
    as the vision encoder and any pretrained text model as the text encoder. The vision and text encoders are loaded
    via the :meth:`~transformers.FlaxAutoModel.from_pretrained` method. The projection layers are automatically added
    to the model and should be fine-tuned on a downstream task, like contrastive image-text modeling.

    In `LiT: Zero-Shot Transfer with Locked-image Text Tuning <https://arxiv.org/abs/2111.07991>`__ it is shown how
    leveraging pre-trained (locked/frozen) image and text models for contrastive learning yields significant
    improvement on new zero-shot vision tasks such as image classification or retrieval.

    After such a Vision-Text-Dual-Encoder model has been trained/fine-tuned, it can be saved/loaded just like any other
    model (see the examples for more information).

    This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
    generic methods the library implements for all its models (such as downloading or saving, resizing the input
    embeddings, pruning heads etc.)

    This model is also a Flax Linen `flax.linen.Module
    <https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
    and refer to the Flax documentation for all matters related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
    - `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
    - `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
    - `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__

    Parameters:
        config (:class:`~transformers.VisionTextDualEncoderConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
            model weights.
        dtype (:obj:`jax.numpy.dtype`, `optional`, defaults to :obj:`jax.numpy.float32`):
            The data type of the computation. Can be one of :obj:`jax.numpy.float32`, :obj:`jax.numpy.float16` (on
            GPUs) and :obj:`jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given ``dtype``.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see
            :meth:`~transformers.FlaxPreTrainedModel.to_fp16` and :meth:`~transformers.FlaxPreTrainedModel.to_bf16`.
"""


VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range ``[0,
            config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`__
        pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            a feature extractor (e.g. if you use ViT as the encoder, you should use
            :class:`~transformers.ViTFeatureExtractor`). See :meth:`transformers.ViTFeatureExtractor.__call__` for
            details.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


class FlaxVisionTextDualEncoderModule(nn.Module):
    config: VisionTextDualEncoderConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        vision_config = self.config.vision_config
        text_config = self.config.text_config

        self.vision_embed_dim = vision_config.hidden_size
        self.text_embed_dim = text_config.hidden_size
        self.projection_dim = self.config.projection_dim

        vision_module = FLAX_MODEL_MAPPING.get(self.config.vision_config.__class__, FlaxCLIPVisionModel).module_class
        text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class

        self.vision_model = vision_module(vision_config, dtype=self.dtype)
        self.text_model = text_module(text_config, dtype=self.dtype)

        self.visual_projection = nn.Dense(
            self.projection_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(0.02),
            use_bias=False,
        )
        self.text_projection = nn.Dense(
            self.projection_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(0.02),
            use_bias=False,
        )

        self.logit_scale = self.param(
            "logit_scale", lambda _, shape: jnp.ones(shape) * self.config.logit_scale_init_value, []
        )

    def __call__(
        self,
        input_ids=None,
        pixel_values=None,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
        text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)

        # cosine similarity as logits
        logit_scale = jnp.exp(self.logit_scale)
        logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
        logits_per_image = logits_per_text.T

        if not return_dict:
            return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)

        return FlaxCLIPOutput(
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
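The last steps of the module's `__call__` (L2-normalize both sets of projected embeddings, take pairwise dot products, scale by `exp(logit_scale)`) can be traced by hand on tiny inputs. The sketch below redoes that arithmetic with plain Python lists instead of `jnp` arrays. The helper names are hypothetical, and the default `logit_scale` is simply the `logit_scale_init_value` from the configuration, which is close to `ln(1 / 0.07)`, i.e. a softmax temperature of about 0.07.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def similarity_logits(text_embeds, image_embeds, logit_scale=2.6592):
    """Return (logits_per_image, logits_per_text) as nested lists."""
    text = [l2_normalize(v) for v in text_embeds]
    image = [l2_normalize(v) for v in image_embeds]
    scale = math.exp(logit_scale)  # the stored parameter is a log-scale
    # Row t, column i: scaled cosine similarity between text t and image i.
    logits_per_text = [[scale * sum(a * b for a, b in zip(tv, iv)) for iv in image] for tv in text]
    # The image-side view is just the transpose.
    logits_per_image = [list(col) for col in zip(*logits_per_text)]
    return logits_per_image, logits_per_text


# Two text embeddings and three image embeddings in a 2-d shared space.
per_image, per_text = similarity_logits(
    [[1.0, 0.0], [0.0, 2.0]],
    [[3.0, 0.0], [1.0, 1.0], [0.0, 5.0]],
)
print(per_text)
```

Because both sides are unit-normalized first, vector length carries no information and only the angle between a text and an image embedding affects the logits.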
|
||||
|
||||
|
||||
@add_start_docstrings(VISION_TEXT_DUAL_ENCODER_START_DOCSTRING)
|
||||
class FlaxVisionTextDualEncoderModel(FlaxPreTrainedModel):
|
||||
config_class = VisionTextDualEncoderConfig
|
||||
module_class = FlaxVisionTextDualEncoderModule
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: VisionTextDualEncoderConfig,
|
||||
input_shape: Optional[Tuple] = None,
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
**kwargs
|
||||
):
|
||||
if input_shape is None:
|
||||
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
|
||||
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
# init input tensor
|
||||
input_ids = jnp.zeros(input_shape[0], dtype="i4")
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
|
||||
token_type_ids = jnp.ones_like(input_ids)
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
pixel_values = jax.random.normal(rng, input_shape[1])
|
||||
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
pixel_values,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
params: dict = None,
|
||||
dropout_rng: jax.random.PRNGKey = None,
|
||||
train: bool = False,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.zeros_like(input_ids)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(pixel_values, dtype=jnp.float32),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
not train,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
rngs=rngs,
|
||||
)
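As a rough, framework-free illustration of the defaults that `__call__` fills in when `position_ids`, `token_type_ids`, or `attention_mask` are omitted, the sketch below uses nested Python lists in place of `jnp` arrays. The helper name `default_text_inputs` is hypothetical and is not part of the library.

```python
from typing import List, Optional, Tuple

Matrix = List[List[int]]


def default_text_inputs(
    input_ids: Matrix,
    attention_mask: Optional[Matrix] = None,
    position_ids: Optional[Matrix] = None,
    token_type_ids: Optional[Matrix] = None,
) -> Tuple[Matrix, Matrix, Matrix]:
    """Hypothetical helper mirroring the default-filling logic with nested lists."""
    seq_len = len(input_ids[0])
    if position_ids is None:
        # Positions 0..seq_len-1, repeated for every row of the batch.
        position_ids = [list(range(seq_len)) for _ in input_ids]
    if token_type_ids is None:
        # Every token is assigned to segment 0.
        token_type_ids = [[0] * seq_len for _ in input_ids]
    if attention_mask is None:
        # Attend to every token; no padding is assumed.
        attention_mask = [[1] * seq_len for _ in input_ids]
    return attention_mask, position_ids, token_type_ids


mask, positions, segments = default_text_inputs([[101, 7592, 102], [101, 2088, 102]])
```

Because the default mask is all ones, callers that pad their batches should pass an explicit `attention_mask`.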
|
||||
|
||||
def get_text_features(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
params: dict = None,
|
||||
dropout_rng: jax.random.PRNGKey = None,
|
||||
train=False,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||
provide it.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
|
||||
for details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
|
||||
Returns:
|
||||
text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`): The text embeddings obtained by
|
||||
applying the projection layer to the pooled output of text model.
|
||||
"""
|
||||
if position_ids is None:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.zeros_like(input_ids)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic):
|
||||
text_outputs = module.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
deterministic=deterministic,
|
||||
)
|
||||
pooled_output = text_outputs[1]
|
||||
text_features = module.text_projection(pooled_output)
|
||||
return text_features
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
not train,
|
||||
method=_get_features,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
def get_image_features(
|
||||
self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
|
||||
using :class:`~transformers.ImageFeatureExtractionMixin`. See
|
||||
:meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
|
||||
|
||||
Returns:
|
||||
image_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`): The image embeddings obtained
|
||||
by applying the projection layer to the pooled output of vision model.
|
||||
"""
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
def _get_features(module, pixel_values, deterministic):
|
||||
vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
|
||||
pooled_output = vision_outputs[1] # pooled_output
|
||||
image_features = module.visual_projection(pooled_output)
|
||||
return image_features
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(pixel_values, dtype=jnp.float32),
|
||||
not train,
|
||||
method=_get_features,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_vision_text_pretrained(
|
||||
cls,
|
||||
vision_model_name_or_path: str = None,
|
||||
text_model_name_or_path: str = None,
|
||||
*model_args,
|
||||
**kwargs,
|
||||
) -> FlaxPreTrainedModel:
|
||||
"""
|
||||
Params:
|
||||
vision_model_name_or_path (:obj:`str`, `optional`, defaults to `None`):
|
||||
Information necessary to initiate the vision model. Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `PyTorch checkpoint folder` (e.g., ``./pt_model``). In this case, ``from_pt``
|
||||
should be set to :obj:`True` and a configuration object should be provided as ``config``
|
||||
argument. This loading path is slower than converting the PyTorch checkpoint in a Flax model
|
||||
using the provided conversion scripts and loading the Flax model afterwards.
|
||||
|
||||
text_model_name_or_path (:obj:`str`, `optional`):
|
||||
Information necessary to initiate the text model. Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `PyTorch checkpoint folder` (e.g., ``./pt_model``). In this case, ``from_pt``
|
||||
should be set to :obj:`True` and a configuration object should be provided as ``config``
|
||||
argument. This loading path is slower than converting the PyTorch checkpoint in a Flax model
|
||||
using the provided conversion scripts and loading the Flax model afterwards.
|
||||
|
||||
model_args (remaining positional arguments, `optional`):
|
||||
All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
|
||||
|
||||
kwargs (remaining dictionary of keyword arguments, `optional`):
|
||||
Can be used to update the configuration object (after it is loaded) and initiate the model (e.g.,
|
||||
:obj:`output_attentions=True`).
|
||||
|
||||
- To update the text configuration, use the prefix `text_` for each configuration parameter.
|
||||
- To update the vision configuration, use the prefix `vision_` for each configuration parameter.
|
||||
- To update the parent model configuration, do not use a prefix for each configuration parameter.
|
||||
|
||||
Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import FlaxVisionTextDualEncoderModel
|
||||
>>> # initialize a model from pretrained ViT and BERT models. Note that the projection layers will be randomly initialized.
|
||||
>>> model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained('google/vit-base-patch16-224', 'bert-base-uncased')
|
||||
>>> # saving model after fine-tuning
|
||||
>>> model.save_pretrained("./vit-bert")
|
||||
>>> # load fine-tuned model
|
||||
>>> model = FlaxVisionTextDualEncoderModel.from_pretrained("./vit-bert")
|
||||
"""
|
||||
|
||||
kwargs_vision = {
|
||||
argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_")
|
||||
}
|
||||
|
||||
kwargs_text = {
|
||||
argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
|
||||
}
|
||||
|
||||
# remove text, vision kwargs from kwargs
|
||||
for key in kwargs_vision.keys():
|
||||
del kwargs["vision_" + key]
|
||||
for key in kwargs_text.keys():
|
||||
del kwargs["text_" + key]
|
||||
|
||||
# Load and initialize the text and vision model
|
||||
vision_model = kwargs_vision.pop("model", None)
|
||||
if vision_model is None:
|
||||
if vision_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `vision_model` is not defined as an argument, a `vision_model_name_or_path` has to be defined"
|
||||
)
|
||||
|
||||
if "config" not in kwargs_vision:
|
||||
vision_config = AutoConfig.from_pretrained(vision_model_name_or_path)
|
||||
|
||||
if vision_config.model_type == "clip":
|
||||
kwargs_vision["config"] = vision_config.vision_config
|
||||
vision_model = FlaxCLIPVisionModel.from_pretrained(
|
||||
vision_model_name_or_path, *model_args, **kwargs_vision
|
||||
)
|
||||
else:
|
||||
kwargs_vision["config"] = vision_config
|
||||
vision_model = FlaxAutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)
|
||||
|
||||
text_model = kwargs_text.pop("model", None)
|
||||
if text_model is None:
|
||||
if text_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `text_model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
|
||||
)
|
||||
|
||||
if "config" not in kwargs_text:
|
||||
text_config = AutoConfig.from_pretrained(text_model_name_or_path)
|
||||
kwargs_text["config"] = text_config
|
||||
|
||||
text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)
|
||||
|
||||
# instantiate config with corresponding kwargs
|
||||
dtype = kwargs.pop("dtype", jnp.float32)
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_model.config, text_model.config, **kwargs)
|
||||
|
||||
# init model
|
||||
model = cls(config, *model_args, dtype=dtype, **kwargs)
|
||||
|
||||
model.params["vision_model"] = vision_model.params
|
||||
model.params["text_model"] = text_model.params
|
||||
|
||||
# the projection layers are always newly initialized when loading the model
|
||||
# using pre-trained vision and text model.
|
||||
logger.warning(
|
||||
"The projection layer and logit scale weights `[('visual_projection', 'kernel'), ('text_projection', 'kernel'), ('logit_scale',)]` "
|
||||
"are newly initialized. You should probably TRAIN this model on a down-stream task "
|
||||
"to be able to use it for predictions and inference."
|
||||
)
|
||||
|
||||
return model
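The prefix-based keyword routing performed at the top of `from_vision_text_pretrained` can be sketched in isolation. This is a minimal illustration rather than the library's code: the helper name `split_prefixed_kwargs` and the sample keys are assumptions, and unlike the method above it returns a new dictionary for the remaining arguments instead of deleting keys in place.

```python
from typing import Any, Dict, Tuple


def split_prefixed_kwargs(
    kwargs: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Hypothetical helper: route `vision_*` and `text_*` keys to their sub-models."""
    # Strip the prefix so each sub-model sees its own parameter names.
    kwargs_vision = {k[len("vision_"):]: v for k, v in kwargs.items() if k.startswith("vision_")}
    kwargs_text = {k[len("text_"):]: v for k, v in kwargs.items() if k.startswith("text_")}
    # Whatever is left is meant for the parent (dual-encoder) configuration.
    remaining = {
        k: v for k, v in kwargs.items() if not (k.startswith("vision_") or k.startswith("text_"))
    }
    return kwargs_vision, kwargs_text, remaining


vision_kwargs, text_kwargs, parent_kwargs = split_prefixed_kwargs(
    {"vision_output_attentions": True, "text_hidden_dropout_prob": 0.0, "projection_dim": 256}
)
```

One consequence of this convention is that a parent-level option whose name happens to start with `text_` or `vision_` would be routed to a sub-model instead.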
|
||||
|
||||
|
||||
VISION_TEXT_DUAL_ENCODER_MODEL_DOCSTRING = r"""
|
||||
Returns:
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> import jax
|
||||
>>> from transformers import FlaxVisionTextDualEncoderModel, VisionTextDualEncoderProcessor, ViTFeatureExtractor, BertTokenizer
|
||||
|
||||
>>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
|
||||
>>> processor = VisionTextDualEncoderProcessor(feature_extractor, tokenizer)
|
||||
>>> model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained("google/vit-base-patch16-224", "bert-base-uncased")
|
||||
|
||||
>>> # contrastive training
|
||||
>>> urls = ["http://images.cocodataset.org/val2017/000000039769.jpg", "https://farm3.staticflickr.com/2674/5850229113_4fe05d5265_z.jpg"]
|
||||
>>> images = [Image.open(requests.get(url, stream=True).raw) for url in urls]
|
||||
>>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=images, return_tensors="np", padding=True)
|
||||
>>> outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, pixel_values=inputs.pixel_values)
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
||||
|
||||
>>> # save and load from pretrained
|
||||
>>> model.save_pretrained("vit-bert")
|
||||
>>> model = FlaxVisionTextDualEncoderModel.from_pretrained("vit-bert")
|
||||
|
||||
>>> # inference
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
||||
>>> probs = jax.nn.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities
|
||||
|
||||
"""
|
||||
|
||||
overwrite_call_docstring(
|
||||
FlaxVisionTextDualEncoderModel,
|
||||
VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING + VISION_TEXT_DUAL_ENCODER_MODEL_DOCSTRING,
|
||||
)
|
||||
append_replace_return_docstrings(
|
||||
FlaxVisionTextDualEncoderModel, output_type=FlaxCLIPOutput, config_class=_CONFIG_FOR_DOC
|
||||
)
|
||||
@@ -0,0 +1,519 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch VisionTextDualEncoder model. """
|
||||
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from ..auto.configuration_auto import AutoConfig
|
||||
from ..auto.modeling_auto import AutoModel
|
||||
from ..clip.modeling_clip import CLIPOutput, CLIPVisionConfig, CLIPVisionModel
|
||||
from .configuration_vision_text_dual_encoder import VisionTextDualEncoderConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "VisionTextDualEncoderConfig"
|
||||
|
||||
VISION_TEXT_DUAL_ENCODER_START_DOCSTRING = r"""
|
||||
This class can be used to initialize a vision-text dual encoder model with any pretrained vision autoencoding model
|
||||
as the vision encoder and any pretrained text model as the text encoder. The vision and text encoders are loaded
|
||||
via the :meth:`~transformers.AutoModel.from_pretrained` method. The projection layers are automatically added to
|
||||
the model and should be fine-tuned on a downstream task, like contrastive image-text modeling.
|
||||
|
||||
In `LiT: Zero-Shot Transfer with Locked-image Text Tuning <https://arxiv.org/abs/2111.07991>`__ it is shown how
|
||||
leveraging pre-trained (locked/frozen) image and text models for contrastive learning yields significant improvement
|
||||
on new zero-shot vision tasks such as image classification or retrieval.
|
||||
|
||||
After such a Vision-Text-Dual-Encoder model has been trained/fine-tuned, it can be saved/loaded just like any other
|
||||
models (see the examples for more information).
|
||||
|
||||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
||||
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
|
||||
pruning heads etc.)
|
||||
|
||||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
|
||||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
|
||||
general usage and behavior.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.VisionTextDualEncoderConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
|
||||
weights.
|
||||
"""
|
||||
|
||||
|
||||
VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||
details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Indices of positions of each input sequence token in the position embeddings. Selected in the range ``[0,
|
||||
config.max_position_embeddings - 1]``.
|
||||
|
||||
`What are position IDs? <../glossary.html#position-ids>`__
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
|
||||
VISION_TEXT_DUAL_ENCODER_VISION_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||
:class:`~transformers.CLIPFeatureExtractor`. See :meth:`transformers.CLIPFeatureExtractor.__call__` for
|
||||
details.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
|
||||
VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
it.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||
details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Indices of positions of each input sequence token in the position embeddings. Selected in the range ``[0,
|
||||
config.max_position_embeddings - 1]``.
|
||||
|
||||
`What are position IDs? <../glossary.html#position-ids>`__
|
||||
pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||
a feature extractor (e.g. if you use ViT as the encoder, you should use
|
||||
:class:`~transformers.ViTFeatureExtractor`). See :meth:`transformers.ViTFeatureExtractor.__call__` for
|
||||
details.
|
||||
return_loss (:obj:`bool`, `optional`):
|
||||
Whether or not to return the contrastive loss.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.contrastive_loss
|
||||
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
|
||||
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
|
||||
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.clip_loss
|
||||
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
|
||||
caption_loss = contrastive_loss(similarity)
|
||||
image_loss = contrastive_loss(similarity.T)
|
||||
return (caption_loss + image_loss) / 2.0
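To make the arithmetic behind `contrastive_loss` and `clip_loss` concrete, here is a small pure-Python sketch of the same symmetric cross-entropy, where the correct class for row `i` is column `i`. It uses nested lists instead of tensors, and the function names `cross_entropy_diagonal` and `symmetric_loss` are hypothetical stand-ins, not library functions.

```python
import math
from typing import List


def cross_entropy_diagonal(logits: List[List[float]]) -> float:
    """Mean cross-entropy where the target for row i is column i."""
    total = 0.0
    for i, row in enumerate(logits):
        # log-sum-exp with the row maximum subtracted for numerical stability
        row_max = max(row)
        log_sum_exp = row_max + math.log(sum(math.exp(x - row_max) for x in row))
        total += log_sum_exp - row[i]
    return total / len(logits)


def symmetric_loss(similarity: List[List[float]]) -> float:
    """Average of the row-wise loss and the column-wise (transposed) loss."""
    transposed = [list(col) for col in zip(*similarity)]
    return (cross_entropy_diagonal(similarity) + cross_entropy_diagonal(transposed)) / 2.0


# With no signal in the similarities the loss equals log(batch_size).
uniform = symmetric_loss([[0.0, 0.0], [0.0, 0.0]])
# A strongly diagonal similarity matrix drives the loss toward zero.
matched = symmetric_loss([[10.0, 0.0], [0.0, 10.0]])
```

Averaging the two directions means a mismatch is penalised whether it shows up when ranking images for a caption or captions for an image.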
|
||||
|
||||
|
||||
@add_start_docstrings(VISION_TEXT_DUAL_ENCODER_START_DOCSTRING)
|
||||
class VisionTextDualEncoderModel(PreTrainedModel):
|
||||
config_class = VisionTextDualEncoderConfig
|
||||
base_model_prefix = "vision_text_dual_encoder"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[VisionTextDualEncoderConfig] = None,
|
||||
vision_model: Optional[PreTrainedModel] = None,
|
||||
text_model: Optional[PreTrainedModel] = None,
|
||||
):
|
||||
|
||||
if config is None and (vision_model is None or text_model is None):
|
||||
raise ValueError("Either a configuration or a vision and a text model have to be provided")
|
||||
|
||||
if config is None:
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_model.config, text_model.config)
|
||||
else:
|
||||
if not isinstance(config, self.config_class):
|
||||
raise ValueError(f"config: {config} has to be of type {self.config_class}")
|
||||
|
||||
# initialize with config
|
||||
super().__init__(config)
|
||||
|
||||
if vision_model is None:
|
||||
if isinstance(config.vision_config, CLIPVisionConfig):
|
||||
vision_model = CLIPVisionModel(config.vision_config)
|
||||
else:
|
||||
vision_model = AutoModel.from_config(config.vision_config)
|
||||
|
||||
if text_model is None:
|
||||
text_model = AutoModel.from_config(config.text_config)
|
||||
|
||||
self.vision_model = vision_model
|
||||
self.text_model = text_model
|
||||
|
||||
# make sure that the individual model's config refers to the shared config
|
||||
# so that the updates to the config will be synced
|
||||
self.vision_model.config = self.config.vision_config
|
||||
self.text_model.config = self.config.text_config
|
||||
|
||||
self.vision_embed_dim = config.vision_config.hidden_size
|
||||
self.text_embed_dim = config.text_config.hidden_size
|
||||
self.projection_dim = config.projection_dim
|
||||
|
||||
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
|
||||
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
|
||||
self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value)
|
||||
|
||||
@add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_TEXT_INPUTS_DOCSTRING)
|
||||
def get_text_features(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
text_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, output_dim)`): The text embeddings
|
||||
obtained by applying the projection layer to the pooled output of the text model.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import VisionTextDualEncoderModel, AutoTokenizer
|
||||
|
||||
>>> model = VisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("clip-italian/clip-italian")
|
||||
|
||||
>>> inputs = tokenizer(["una foto di un gatto", "una foto di un cane"], padding=True, return_tensors="pt")
|
||||
>>> text_features = model.get_text_features(**inputs)
|
||||
"""
|
||||
text_outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
pooled_output = text_outputs[1]
|
||||
text_features = self.text_projection(pooled_output)
|
||||
|
||||
return text_features
|
||||
|
||||
@add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_VISION_INPUTS_DOCSTRING)
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
image_features (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, output_dim)`): The image embeddings
|
||||
obtained by applying the projection layer to the pooled output of the vision model.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import VisionTextDualEncoderModel, AutoFeatureExtractor
|
||||
|
||||
>>> model = VisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian")
|
||||
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = feature_extractor(images=image, return_tensors="pt")
|
||||
|
||||
>>> image_features = model.get_image_features(**inputs)
|
||||
"""
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
pooled_output = vision_outputs[1] # pooled_output
|
||||
image_features = self.visual_projection(pooled_output)
|
||||
|
||||
return image_features
|
||||
|
||||
@add_start_docstrings_to_model_forward(VISION_TEXT_DUAL_ENCODER_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CLIPOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
return_loss=None,
|
||||
token_type_ids=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor, ViTFeatureExtractor, BertTokenizer
|
||||
|
||||
>>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
|
||||
>>> processor = VisionTextDualEncoderProcessor(feature_extractor, tokenizer)
|
||||
>>> model = VisionTextDualEncoderModel.from_vision_text_pretrained("google/vit-base-patch16-224", "bert-base-uncased")
|
||||
|
||||
>>> # contrastive training
|
||||
>>> urls = ["http://images.cocodataset.org/val2017/000000039769.jpg", "https://farm3.staticflickr.com/2674/5850229113_4fe05d5265_z.jpg"]
|
||||
>>> images = [Image.open(requests.get(url, stream=True).raw) for url in urls]
|
||||
>>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=images, return_tensors="pt", padding=True)
|
||||
>>> outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, pixel_values=inputs.pixel_values, return_loss=True)
|
||||
>>> loss, logits_per_image = outputs.loss, outputs.logits_per_image # this is the image-text similarity score
|
||||
|
||||
>>> # save and load from pretrained
|
||||
>>> model.save_pretrained("vit-bert")
|
||||
>>> model = VisionTextDualEncoderModel.from_pretrained("vit-bert")
|
||||
|
||||
>>> # inference
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
|
||||
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
|
||||
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
image_embeds = vision_outputs[1] # pooler_output
|
||||
image_embeds = self.visual_projection(image_embeds)
|
||||
|
||||
text_embeds = text_outputs[1] # pooler_output
|
||||
text_embeds = self.text_projection(text_embeds)
|
||||
|
||||
# normalized features
|
||||
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
|
||||
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
|
||||
logits_per_image = logits_per_text.T
|
||||
|
||||
loss = None
|
||||
if return_loss:
|
||||
loss = clip_loss(logits_per_text)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CLIPOutput(
|
||||
loss=loss,
|
||||
logits_per_image=logits_per_image,
|
||||
logits_per_text=logits_per_text,
|
||||
text_embeds=text_embeds,
|
||||
image_embeds=image_embeds,
|
||||
text_model_output=text_outputs,
|
||||
vision_model_output=vision_outputs,
|
||||
)
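The similarity computation in `forward` (L2-normalise both sets of projected embeddings, take dot products, scale by the exponentiated `logit_scale`) can be illustrated without PyTorch. The sketch below is a simplified stand-in using Python lists; `l2_normalize` and `similarity_logits` are hypothetical names, and the sample vectors are made up for the example.

```python
import math
from typing import List

Vector = List[float]


def l2_normalize(vec: Vector) -> Vector:
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def similarity_logits(
    text_embeds: List[Vector], image_embeds: List[Vector], logit_scale_param: float
) -> List[List[float]]:
    """Scaled cosine similarities; rows index texts, columns index images."""
    scale = math.exp(logit_scale_param)
    text_n = [l2_normalize(t) for t in text_embeds]
    image_n = [l2_normalize(v) for v in image_embeds]
    return [[scale * sum(a * b for a, b in zip(t, v)) for v in image_n] for t in text_n]


logits_per_text = similarity_logits([[3.0, 4.0], [1.0, 0.0]], [[6.0, 8.0], [0.0, 2.0]], 0.0)
# The image-centric view is the transpose of the text-centric one.
logits_per_image = [list(col) for col in zip(*logits_per_text)]
```

Since the embeddings are normalised first, the unscaled values are cosine similarities in `[-1, 1]`, and the learned scale controls how sharply a softmax over them peaks.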
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
# At the moment fast initialization is not supported
|
||||
# for composite models
|
||||
kwargs["_fast_init"] = False
|
||||
return super().from_pretrained(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_vision_text_pretrained(
|
||||
cls,
|
||||
vision_model_name_or_path: str = None,
|
||||
text_model_name_or_path: str = None,
|
||||
*model_args,
|
||||
**kwargs,
|
||||
) -> PreTrainedModel:
|
||||
"""
|
||||
Params:
|
||||
vision_model_name_or_path (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
Information necessary to initiate the vision model. Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In this case, ``from_pt``
|
||||
should be set to :obj:`True` and a configuration object should be provided as ``config``
|
||||
argument. This loading path is slower than converting the PyTorch checkpoint in a Flax model
|
||||
using the provided conversion scripts and loading the Flax model afterwards.
|
||||
|
||||
text_model_name_or_path (:obj:`str`, `optional`):
|
||||
Information necessary to initiate the text model. Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In this case, ``from_pt``
|
||||
should be set to :obj:`True` and a configuration object should be provided as ``config``
|
||||
argument. This loading path is slower than converting the PyTorch checkpoint in a Flax model
|
||||
using the provided conversion scripts and loading the Flax model afterwards.
|
||||
|
||||
model_args (remaining positional arguments, `optional`):
|
||||
All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
|
||||
|
||||
kwargs (remaining dictionary of keyword arguments, `optional`):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
:obj:`output_attentions=True`).
|
||||
|
||||
- To update the text configuration, use the prefix `text_` for each configuration parameter.
|
||||
- To update the vision configuration, use the prefix `vision_` for each configuration parameter.
|
||||
- To update the parent model configuration, do not use a prefix for each configuration parameter.
|
||||
|
||||
Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import VisionTextDualEncoderModel
|
||||
>>> # initialize a model from pretrained ViT and BERT models. Note that the projection layers will be randomly initialized.
|
||||
>>> model = VisionTextDualEncoderModel.from_vision_text_pretrained('google/vit-base-patch16-224', 'bert-base-uncased')
|
||||
>>> # saving model after fine-tuning
|
||||
>>> model.save_pretrained("./vit-bert")
|
||||
>>> # load fine-tuned model
|
||||
>>> model = VisionTextDualEncoderModel.from_pretrained("./vit-bert")
|
||||
"""
|
||||
kwargs_vision = {
|
||||
argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_")
|
||||
}
|
||||
|
||||
kwargs_text = {
|
||||
argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
|
||||
}
|
||||
|
||||
# remove vision, text kwargs from kwargs
|
||||
for key in kwargs_vision.keys():
|
||||
del kwargs["vision_" + key]
|
||||
for key in kwargs_text.keys():
|
||||
del kwargs["text_" + key]
|
||||
|
||||
# Load and initialize the vision and text model
|
||||
vision_model = kwargs_vision.pop("model", None)
|
||||
if vision_model is None:
|
||||
if vision_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `vision_model` is not defined as an argument, a `vision_model_name_or_path` has to be defined"
|
||||
)
|
||||
|
||||
if "config" not in kwargs_vision:
|
||||
vision_config = AutoConfig.from_pretrained(vision_model_name_or_path)
|
||||
|
||||
if vision_config.model_type == "clip":
|
||||
kwargs_vision["config"] = vision_config.vision_config
|
||||
vision_model = CLIPVisionModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)
|
||||
# TODO: Should we use the pre-trained projection as well ?
|
||||
else:
|
||||
kwargs_vision["config"] = vision_config
|
||||
vision_model = AutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)
|
||||
|
||||
text_model = kwargs_text.pop("model", None)
|
||||
if text_model is None:
|
||||
if text_model_name_or_path is None:
|
||||
raise ValueError(
|
||||
"If `text_model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
|
||||
)
|
||||
|
||||
if "config" not in kwargs_text:
|
||||
text_config = AutoConfig.from_pretrained(text_model_name_or_path)
|
||||
kwargs_text["config"] = text_config
|
||||
|
||||
text_model = AutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)
|
||||
|
||||
# instantiate config with corresponding kwargs
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_model.config, text_model.config, **kwargs)
|
||||
|
||||
# init model
|
||||
model = cls(config=config, vision_model=vision_model, text_model=text_model)
|
||||
|
||||
# the projection layers are always newly initialized when loading the model
|
||||
# using pre-trained vision and text model.
|
||||
logger.warning(
|
||||
"The projection layer and logit scale weights `['visual_projection.weight', 'text_projection.weight', 'logit_scale']` "
|
||||
"are newly initialized. You should probably TRAIN this model on a down-stream task "
|
||||
"to be able to use it for predictions and inference."
|
||||
)
|
||||
|
||||
return model
|
||||
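The prefix routing performed at the top of `from_vision_text_pretrained` can be reproduced in isolation. This is a minimal sketch under the assumption that only the `vision_` and `text_` prefixes matter; `split_prefixed_kwargs` is a hypothetical helper, not part of the library.

```python
def split_prefixed_kwargs(kwargs, prefixes=("vision_", "text_")):
    # Split kwargs into one dict per prefix (prefix stripped) plus the unprefixed remainder.
    routed = {prefix: {} for prefix in prefixes}
    remainder = {}
    for name, value in kwargs.items():
        for prefix in prefixes:
            if name.startswith(prefix):
                routed[prefix][name[len(prefix):]] = value
                break
        else:
            # No prefix matched: the argument belongs to the parent configuration.
            remainder[name] = value
    return routed, remainder


routed, remainder = split_prefixed_kwargs(
    {"vision_from_pt": True, "text_from_pt": True, "projection_dim": 256}
)
```

One consequence of this scheme is that any keyword whose name happens to start with `vision_` or `text_` is sent to the corresponding sub-model and never reaches the parent configuration.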
@@ -0,0 +1,185 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for VisionTextDualEncoder
|
||||
"""
|
||||
from typing import Union
|
||||
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
||||
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ..auto.feature_extraction_auto import AutoFeatureExtractor
|
||||
from ..auto.tokenization_auto import AutoTokenizer
|
||||
|
||||
|
||||
class VisionTextDualEncoderProcessor:
|
||||
r"""
|
||||
Constructs a VisionTextDualEncoder processor which wraps a vision feature extractor and a tokenizer into a single
|
||||
processor.
|
||||
|
||||
:class:`~transformers.VisionTextDualEncoderProcessor` offers all the functionalities of
|
||||
:class:`~transformers.AutoFeatureExtractor` and :class:`~transformers.AutoTokenizer`. See the
|
||||
:meth:`~transformers.VisionTextDualEncoderProcessor.__call__` and
|
||||
:meth:`~transformers.VisionTextDualEncoderProcessor.decode` for more information.
|
||||
|
||||
Args:
|
||||
feature_extractor (:class:`~transformers.AutoFeatureExtractor`):
|
||||
The feature extractor is a required input.
|
||||
tokenizer (:class:`~transformers.PreTrainedTokenizer`):
|
||||
The tokenizer is a required input.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, feature_extractor: FeatureExtractionMixin, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
||||
):
|
||||
if not isinstance(feature_extractor, FeatureExtractionMixin):
|
||||
raise ValueError(
|
||||
f"`feature_extractor` has to be of type {FeatureExtractionMixin}, but is {type(feature_extractor)}"
|
||||
)
|
||||
if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
|
||||
raise ValueError(
|
||||
f"`tokenizer` has to be of type `PreTrainedTokenizer` or `PreTrainedTokenizerFast`, but is {type(tokenizer)}"
|
||||
)
|
||||
|
||||
self.feature_extractor = feature_extractor
|
||||
self.tokenizer = tokenizer
|
||||
self.current_processor = self.feature_extractor
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
"""
|
||||
Save a VisionTextDualEncoder feature extractor object and VisionTextDualEncoder tokenizer object to the
|
||||
directory ``save_directory``, so that it can be re-loaded using the
|
||||
:func:`~transformers.VisionTextDualEncoderProcessor.from_pretrained` class method.
|
||||
|
||||
.. note::
|
||||
|
||||
This class method is simply calling :meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` and
|
||||
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained`. Please refer to the
|
||||
docstrings of the methods above for more information.
|
||||
|
||||
Args:
|
||||
save_directory (:obj:`str` or :obj:`os.PathLike`):
|
||||
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
|
||||
be created if it does not exist).
|
||||
"""
|
||||
|
||||
self.feature_extractor.save_pretrained(save_directory)
|
||||
self.tokenizer.save_pretrained(save_directory)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a :class:`~transformers.VisionTextDualEncoderProcessor` from a pretrained VisionTextDualEncoder
|
||||
processor.
|
||||
|
||||
.. note::
|
||||
|
||||
This class method is simply calling AutoFeatureExtractor's
|
||||
:meth:`~transformers.PreTrainedFeatureExtractor.from_pretrained` and AutoTokenizer's
|
||||
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`. Please refer to the
|
||||
docstrings of the methods above for more information.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the `model id` of a pretrained feature_extractor hosted inside a model repo on
|
||||
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
|
||||
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- a path to a `directory` containing a feature extractor file saved using the
|
||||
:meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g.,
|
||||
``./my_model_directory/``.
|
||||
- a path or url to a saved feature extractor JSON `file`, e.g.,
|
||||
``./my_model_directory/preprocessor_config.json``.
|
||||
|
||||
**kwargs
|
||||
Additional keyword arguments passed along to both :class:`~transformers.PreTrainedFeatureExtractor` and
|
||||
:class:`~transformers.PreTrainedTokenizer`
|
||||
"""
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
||||
|
||||
def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
|
||||
"""
|
||||
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the
|
||||
:obj:`text` and :obj:`kwargs` arguments to VisionTextDualEncoderTokenizer's
|
||||
:meth:`~transformers.PreTrainedTokenizer.__call__` if :obj:`text` is not :obj:`None` to encode the text. To
|
||||
prepare the image(s), this method forwards the :obj:`images` and :obj:`kwargs` arguments to
|
||||
AutoFeatureExtractor's :meth:`~transformers.AutoFeatureExtractor.__call__` if :obj:`images` is not :obj:`None`.
|
||||
Please refer to the docstring of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
text (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
:obj:`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
images (:obj:`PIL.Image.Image`, :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`List[PIL.Image.Image]`, :obj:`List[np.ndarray]`, :obj:`List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
|
||||
number of channels, H and W are image height and width.
|
||||
|
||||
return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||
* :obj:`'np'`: Return NumPy :obj:`np.ndarray` objects.
|
||||
* :obj:`'jax'`: Return JAX :obj:`jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when :obj:`text` is not :obj:`None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
:obj:`return_attention_mask=True` or if `"attention_mask"` is in :obj:`self.model_input_names` and if
|
||||
:obj:`text` is not :obj:`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when :obj:`images` is not :obj:`None`.
|
||||
"""
|
||||
|
||||
if text is None and images is None:
|
||||
raise ValueError("You have to specify either text or images. Both cannot be none.")
|
||||
|
||||
if text is not None:
|
||||
encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
|
||||
|
||||
if images is not None:
|
||||
image_features = self.feature_extractor(images, return_tensors=return_tensors, **kwargs)
|
||||
|
||||
if text is not None and images is not None:
|
||||
encoding["pixel_values"] = image_features.pixel_values
|
||||
return encoding
|
||||
elif text is not None:
|
||||
return encoding
|
||||
else:
|
||||
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to VisionTextDualEncoderTokenizer's
|
||||
:meth:`~transformers.PreTrainedTokenizer.batch_decode`. Please refer to the docstring of this method for more
|
||||
information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to VisionTextDualEncoderTokenizer's
|
||||
:meth:`~transformers.PreTrainedTokenizer.decode`. Please refer to the docstring of this method for more
|
||||
information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
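The branching in `VisionTextDualEncoderProcessor.__call__` reduces to merging two dictionaries. The sketch below illustrates that control flow with plain dicts; `fake_tokenizer`, `fake_feature_extractor` and `process` are hypothetical stand-ins, and unlike the real processor they do not return a `BatchEncoding` or convert tensors.

```python
def fake_tokenizer(text):
    # Hypothetical stand-in: one "token id" per whitespace-separated word.
    batch = [text] if isinstance(text, str) else text
    return {"input_ids": [[len(word) for word in sentence.split()] for sentence in batch]}


def fake_feature_extractor(images):
    # Hypothetical stand-in: passes the images through as "pixel_values".
    return {"pixel_values": list(images)}


def process(text=None, images=None):
    # At least one modality is required, as in the processor above.
    if text is None and images is None:
        raise ValueError("You have to specify either text or images. Both cannot be none.")
    encoding = {}
    if text is not None:
        encoding.update(fake_tokenizer(text))
    if images is not None:
        # Image features are attached under the "pixel_values" key.
        encoding["pixel_values"] = fake_feature_extractor(images)["pixel_values"]
    return encoding


both = process(text=["a cat", "a dog"], images=[[0.1, 0.2]])
only_text = process(text="a cat")
only_images = process(images=[[0.1, 0.2]])
```

When both modalities are given, the result carries the tokenizer fields plus `pixel_values`, which is the set of keys the model's forward pass expects.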
@@ -1292,6 +1292,18 @@ class FlaxVisionEncoderDecoderModel:
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxVisionTextDualEncoderModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["flax"])
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxViTForImageClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
@@ -4849,6 +4849,18 @@ class VisionEncoderDecoderModel:
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class VisionTextDualEncoderModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
||||
395
tests/test_modeling_flax_vision_text_dual_encoder.py
Normal file
@@ -0,0 +1,395 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the Flax VisionTextDualEncoder model. """
|
||||
|
||||
|
||||
import collections
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.file_utils import is_flax_available, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import (
|
||||
is_pt_flax_cross_test,
|
||||
require_flax,
|
||||
require_torch,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_modeling_flax_bert import FlaxBertModelTester
|
||||
from .test_modeling_flax_clip import FlaxCLIPVisionModelTester
|
||||
from .test_modeling_flax_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from .test_modeling_flax_vit import FlaxViTModelTester
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers import (
|
||||
FlaxBertModel,
|
||||
FlaxCLIPVisionModel,
|
||||
FlaxVisionTextDualEncoderModel,
|
||||
FlaxViTModel,
|
||||
VisionTextDualEncoderConfig,
|
||||
VisionTextDualEncoderProcessor,
|
||||
)
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import VisionTextDualEncoderModel
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# Inspired by
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
|
||||
# From PyTorch internals
|
||||
def to_2tuple(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return (x, x)
|
||||
|
||||
|
||||
@require_flax
|
||||
class VisionTextDualEncoderMixin:
|
||||
def get_vision_text_model(self, config, text_config):
|
||||
pass
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pass
|
||||
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
pass
|
||||
|
||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (> {tol}).")
|
||||
|
||||
def check_model_from_pretrained_configs(
|
||||
self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
|
||||
):
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
model = FlaxVisionTextDualEncoderModel(config)
|
||||
|
||||
output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
|
||||
|
||||
self.assertEqual(output["text_embeds"].shape, (input_ids.shape[0], config.projection_dim))
|
||||
self.assertEqual(output["image_embeds"].shape, (pixel_values.shape[0], config.projection_dim))
|
||||
|
||||
def check_vision_text_dual_encoder_from_pretrained(
|
||||
self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
|
||||
):
|
||||
|
||||
vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
|
||||
kwargs = {"vision_model": vision_model, "text_model": text_model}
|
||||
model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(**kwargs)
|
||||
|
||||
output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
|
||||
|
||||
self.assertEqual(output["text_embeds"].shape, (input_ids.shape[0], model.config.projection_dim))
|
||||
self.assertEqual(output["image_embeds"].shape, (pixel_values.shape[0], model.config.projection_dim))
|
||||
|
||||
def check_save_load(self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs):
|
||||
vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
|
||||
kwargs = {"vision_model": vision_model, "text_model": text_model}
|
||||
model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(**kwargs)
|
||||
|
||||
output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
|
||||
out_1 = output[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname)
|
||||
|
||||
after_output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
|
||||
out_2 = after_output[0]
|
||||
max_diff = np.amax(np.abs(out_2 - out_1))
|
||||
self.assertLessEqual(max_diff, 1e-3)
|
||||
|
||||
def check_vision_text_output_attention(
|
||||
self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
|
||||
):
|
||||
vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
|
||||
kwargs = {"vision_model": vision_model, "text_model": text_model}
|
||||
model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(**kwargs)
|
||||
|
||||
output = model(
|
||||
input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, output_attentions=True
|
||||
)
|
||||
|
||||
vision_attentions = output.vision_model_output.attentions
|
||||
self.assertEqual(len(vision_attentions), vision_config.num_hidden_layers)
|
||||
|
||||
# in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
image_size = to_2tuple(vision_model.config.image_size)
|
||||
patch_size = to_2tuple(vision_model.config.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
seq_len = num_patches + 1
|
||||
self.assertEqual(vision_attentions[0].shape[-3:], (vision_config.num_attention_heads, seq_len, seq_len))
|
||||
|
||||
text_attentions = output.text_model_output.attentions
|
||||
self.assertEqual(len(text_attentions), text_config.num_hidden_layers)
|
||||
|
||||
self.assertEqual(
|
||||
text_attentions[0].shape[-3:],
|
||||
(text_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]),
|
||||
)
|
||||
|
||||
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
|
||||
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = VisionTextDualEncoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
|
||||
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
fx_model = FlaxVisionTextDualEncoderModel(config)
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict):
|
||||
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
fx_model = FlaxVisionTextDualEncoderModel(config)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
|
||||
|
||||
def test_model_from_pretrained_configs(self):
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_model_from_pretrained_configs(**inputs_dict)
|
||||
|
||||
def test_vision_text_dual_encoder_from_pretrained(self):
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_vision_text_dual_encoder_from_pretrained(**inputs_dict)
|
||||
|
||||
def test_save_load(self):
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_save_load(**inputs_dict)
|
||||
|
||||
def test_vision_text_output_attention(self):
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_vision_text_output_attention(**inputs_dict)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_pt_flax_equivalence(self):
|
||||
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
vision_config = config_inputs_dict.pop("vision_config")
|
||||
text_config = config_inputs_dict.pop("text_config")
|
||||
|
||||
inputs_dict = config_inputs_dict
|
||||
|
||||
self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(vision_config, text_config, inputs_dict)
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2, inputs = self.get_pretrained_model_and_inputs()
|
||||
|
||||
outputs = model_2(**inputs)
|
||||
out_2 = outputs[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
model_2.save_pretrained(tmp_dirname)
|
||||
model_1 = FlaxVisionTextDualEncoderModel.from_pretrained(tmp_dirname)
|
||||
|
||||
after_outputs = model_1(**inputs)
|
||||
out_1 = after_outputs[0]
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxViTBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
|
||||
"hf-internal-testing/tiny-random-vit",
|
||||
"hf-internal-testing/tiny-bert",
|
||||
vision_from_pt=True,
|
||||
text_from_pt=True,
|
||||
)
|
||||
batch_size = 13
|
||||
pixel_values = floats_tensor(
|
||||
[
|
||||
batch_size,
|
||||
model.config.vision_config.num_channels,
|
||||
model.config.vision_config.image_size,
|
||||
model.config.vision_config.image_size,
|
||||
]
|
||||
)
|
||||
input_ids = ids_tensor([batch_size, 4], model.config.text_config.vocab_size)
|
||||
attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs = {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask}
|
||||
|
||||
return model, inputs
|
||||
|
||||
def get_vision_text_model(self, vision_config, text_config):
|
||||
vision_model = FlaxViTModel(vision_config)
|
||||
text_model = FlaxBertModel(text_config)
|
||||
return vision_model, text_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
vit_model_tester = FlaxViTModelTester(self)
|
||||
bert_model_tester = FlaxBertModelTester(self)
|
||||
vision_config_and_inputs = vit_model_tester.prepare_config_and_inputs()
|
||||
text_config_and_inputs = bert_model_tester.prepare_config_and_inputs()
|
||||
|
||||
vision_config, pixel_values = vision_config_and_inputs
|
||||
|
||||
text_config, input_ids, token_type_ids, attention_mask = text_config_and_inputs
|
||||
|
||||
# collect the configs and inputs consumed by the `check_*` helpers
|
||||
return {
|
||||
"text_config": text_config,
|
||||
"vision_config": vision_config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": attention_mask,
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
|
||||
|
||||
@require_torch
|
||||
class FlaxCLIPVisionBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
model = FlaxVisionTextDualEncoderModel.from_vision_text_pretrained(
|
||||
"hf-internal-testing/tiny-random-clip",
|
||||
"hf-internal-testing/tiny-bert",
|
||||
vision_from_pt=True,
|
||||
text_from_pt=True,
|
||||
)
|
||||
batch_size = 13
|
||||
pixel_values = floats_tensor(
|
||||
[
|
||||
batch_size,
|
||||
model.config.vision_config.num_channels,
|
||||
model.config.vision_config.image_size,
|
||||
model.config.vision_config.image_size,
|
||||
]
|
||||
)
|
||||
input_ids = ids_tensor([batch_size, 4], model.config.text_config.vocab_size)
|
||||
attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs = {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask}
|
||||
|
||||
return model, inputs
|
||||
|
||||
def get_vision_text_model(self, vision_config, text_config):
|
||||
vision_model = FlaxCLIPVisionModel(vision_config)
|
||||
text_model = FlaxBertModel(text_config)
|
||||
return vision_model, text_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
clip_model_tester = FlaxCLIPVisionModelTester(self)
|
||||
bert_model_tester = FlaxBertModelTester(self)
|
||||
vision_config_and_inputs = clip_model_tester.prepare_config_and_inputs()
|
||||
text_config_and_inputs = bert_model_tester.prepare_config_and_inputs()
|
||||
|
||||
vision_config, pixel_values = vision_config_and_inputs
|
||||
|
||||
text_config, input_ids, token_type_ids, attention_mask = text_config_and_inputs
|
||||
|
||||
# bundle the configs and inputs as keyword arguments for the check_* helpers
|
||||
return {
|
||||
"text_config": text_config,
|
||||
"vision_config": vision_config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": attention_mask,
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
|
||||
|
||||
@require_flax
|
||||
@require_vision
|
||||
class FlaxVisionTextDualEncoderIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference(self):
|
||||
model = FlaxVisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian", logit_scale_init_value=1)
|
||||
processor = VisionTextDualEncoderProcessor.from_pretrained("clip-italian/clip-italian")
|
||||
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
inputs = processor(
|
||||
text=["una foto di un gatto", "una foto di un cane"], images=image, padding=True, return_tensors="np"
|
||||
)
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify the logits
|
||||
self.assertEqual(outputs.logits_per_image.shape, (inputs.pixel_values.shape[0], inputs.input_ids.shape[0]))
|
||||
self.assertEqual(
|
||||
outputs.logits_per_text.shape,
|
||||
(inputs.input_ids.shape[0], inputs.pixel_values.shape[0]),
|
||||
)
|
||||
|
||||
expected_logits = np.array([[1.2284727, 0.3104122]])
|
||||
|
||||
self.assertTrue(np.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
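The shape assertions in the integration test above follow from how a dual encoder scores image/text pairs. The following is a minimal NumPy sketch of a CLIP-style similarity computation, not the model's actual implementation: the function name `similarity_logits`, the plain multiplicative `scale` argument, and the random embeddings are all assumptions made for illustration.

```python
import numpy as np


def similarity_logits(image_embeds, text_embeds, scale=1.0):
    """Cosine-similarity logits between every text and every image (illustrative only)."""
    # L2-normalize each embedding so the dot product becomes a cosine similarity.
    image_embeds = image_embeds / np.linalg.norm(image_embeds, axis=-1, keepdims=True)
    text_embeds = text_embeds / np.linalg.norm(text_embeds, axis=-1, keepdims=True)
    # Rows index texts, columns index images.
    logits_per_text = scale * (text_embeds @ image_embeds.T)
    # The image view is simply the transpose.
    return logits_per_text.T, logits_per_text


rng = np.random.default_rng(0)
image_embeds = rng.normal(size=(1, 8))  # one image, hypothetical projection_dim of 8
text_embeds = rng.normal(size=(2, 8))   # two captions
logits_per_image, logits_per_text = similarity_logits(image_embeds, text_embeds)
```

With one image and two captions, `logits_per_image` has shape `(1, 2)` and `logits_per_text` has shape `(2, 1)`, which matches the `(pixel_values.shape[0], input_ids.shape[0])` pattern the test asserts.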
|
||||
527
tests/test_modeling_vision_text_dual_encoder.py
Normal file
@@ -0,0 +1,527 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch VisionTextDualEncoder model. """
|
||||
|
||||
|
||||
import collections.abc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.file_utils import is_flax_available, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_vision, slow, torch_device
|
||||
|
||||
from .test_modeling_bert import BertModelTester
|
||||
from .test_modeling_clip import CLIPVisionModelTester
|
||||
from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from .test_modeling_deit import DeiTModelTester
|
||||
from .test_modeling_roberta import RobertaModelTester
|
||||
from .test_modeling_vit import ViTModelTester
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
BertModel,
|
||||
CLIPVisionModel,
|
||||
DeiTModel,
|
||||
RobertaModel,
|
||||
VisionTextDualEncoderConfig,
|
||||
VisionTextDualEncoderModel,
|
||||
ViTModel,
|
||||
)
|
||||
|
||||
if is_flax_available():
|
||||
from transformers import FlaxVisionTextDualEncoderModel
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import VisionTextDualEncoderProcessor
|
||||
|
||||
|
||||
# Inspired by
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
|
||||
# From PyTorch internals
|
||||
def to_2tuple(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return (x, x)
|
||||
|
||||
|
||||
@require_torch
|
||||
class VisionTextDualEncoderMixin:
|
||||
def get_vision_text_model(self, vision_config, text_config):
|
||||
pass
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pass
|
||||
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
pass
|
||||
|
||||
def check_model_from_pretrained_configs(
|
||||
self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
|
||||
):
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
model = VisionTextDualEncoderModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
|
||||
|
||||
self.assertEqual(output["text_embeds"].shape, (input_ids.shape[0], config.projection_dim))
|
||||
self.assertEqual(output["image_embeds"].shape, (pixel_values.shape[0], config.projection_dim))
|
||||
|
||||
def check_vision_text_dual_encoder_model(
|
||||
self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
|
||||
):
|
||||
vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
|
||||
model = VisionTextDualEncoderModel(vision_model=vision_model, text_model=text_model)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
|
||||
|
||||
self.assertEqual(output["text_embeds"].shape, (input_ids.shape[0], model.config.projection_dim))
|
||||
self.assertEqual(output["image_embeds"].shape, (pixel_values.shape[0], model.config.projection_dim))
|
||||
|
||||
def check_vision_text_dual_encoder_from_pretrained(
|
||||
self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
|
||||
):
|
||||
|
||||
vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
|
||||
kwargs = {"vision_model": vision_model, "text_model": text_model}
|
||||
model = VisionTextDualEncoderModel.from_vision_text_pretrained(**kwargs)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
|
||||
|
||||
self.assertEqual(output["text_embeds"].shape, (input_ids.shape[0], model.config.projection_dim))
|
||||
self.assertEqual(output["image_embeds"].shape, (pixel_values.shape[0], model.config.projection_dim))
|
||||
|
||||
def check_save_load(self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs):
|
||||
vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
|
||||
model = VisionTextDualEncoderModel(vision_model=vision_model, text_model=text_model)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
|
||||
out_1 = output[0].cpu().numpy()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = VisionTextDualEncoderModel.from_pretrained(tmpdirname).eval()
|
||||
model.to(torch_device)
|
||||
|
||||
after_output = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
|
||||
out_2 = after_output[0].cpu().numpy()
|
||||
max_diff = np.amax(np.abs(out_2 - out_1))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def check_vision_text_output_attention(
|
||||
self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
|
||||
):
|
||||
vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
|
||||
model = VisionTextDualEncoderModel(vision_model=vision_model, text_model=text_model)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output = model(
|
||||
input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, output_attentions=True
|
||||
)
|
||||
|
||||
vision_attentions = output.vision_model_output.attentions
|
||||
self.assertEqual(len(vision_attentions), vision_config.num_hidden_layers)
|
||||
|
||||
# in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
image_size = to_2tuple(vision_model.config.image_size)
|
||||
patch_size = to_2tuple(vision_model.config.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
seq_len = num_patches + 1
|
||||
self.assertEqual(vision_attentions[0].shape[-3:], (vision_config.num_attention_heads, seq_len, seq_len))
|
||||
|
||||
text_attentions = output.text_model_output.attentions
|
||||
self.assertEqual(len(text_attentions), text_config.num_hidden_layers)
|
||||
|
||||
self.assertEqual(
|
||||
text_attentions[0].shape[-3:],
|
||||
(text_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]),
|
||||
)
|
||||
|
||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (> {tol}).")
|
||||
|
||||
def check_pt_flax_equivalence(self, pt_model, fx_model, input_ids, attention_mask, pixel_values, **kwargs):
|
||||
|
||||
pt_model.to(torch_device)
|
||||
pt_model.eval()
|
||||
|
||||
# prepare inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
|
||||
pt_inputs = inputs_dict
|
||||
flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**flax_inputs).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = FlaxVisionTextDualEncoderModel.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = VisionTextDualEncoderModel.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
pt_model_loaded.to(torch_device)
|
||||
pt_model_loaded.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
|
||||
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
fx_model = FlaxVisionTextDualEncoderModel(config)
|
||||
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict)
|
||||
|
||||
def check_equivalence_flax_to_pt(self, vision_config, text_config, inputs_dict):
|
||||
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
||||
pt_model = VisionTextDualEncoderModel(config)
|
||||
fx_model = FlaxVisionTextDualEncoderModel(config)
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
|
||||
self.check_pt_flax_equivalence(pt_model, fx_model, **inputs_dict)
|
||||
|
||||
def test_vision_text_dual_encoder_model(self):
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_vision_text_dual_encoder_model(**inputs_dict)
|
||||
|
||||
def test_model_from_pretrained_configs(self):
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_model_from_pretrained_configs(**inputs_dict)
|
||||
|
||||
def test_vision_text_dual_encoder_from_pretrained(self):
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_vision_text_dual_encoder_from_pretrained(**inputs_dict)
|
||||
|
||||
def test_save_load(self):
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_save_load(**inputs_dict)
|
||||
|
||||
def test_vision_text_output_attention(self):
|
||||
inputs_dict = self.prepare_config_and_inputs()
|
||||
self.check_vision_text_output_attention(**inputs_dict)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_pt_flax_equivalence(self):
|
||||
|
||||
config_inputs_dict = self.prepare_config_and_inputs()
|
||||
vision_config = config_inputs_dict.pop("vision_config")
|
||||
text_config = config_inputs_dict.pop("text_config")
|
||||
|
||||
inputs_dict = config_inputs_dict
|
||||
|
||||
self.check_equivalence_pt_to_flax(vision_config, text_config, inputs_dict)
|
||||
self.check_equivalence_flax_to_pt(vision_config, text_config, inputs_dict)
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2, inputs = self.get_pretrained_model_and_inputs()
|
||||
model_2.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model_2(**inputs)
|
||||
out_2 = outputs[0].cpu().numpy()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
model_2.save_pretrained(tmp_dirname)
|
||||
model_1 = VisionTextDualEncoderModel.from_pretrained(tmp_dirname)
|
||||
model_1.to(torch_device)
|
||||
|
||||
after_outputs = model_1(**inputs)
|
||||
out_1 = after_outputs[0].cpu().numpy()
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
|
||||
@require_torch
|
||||
class ViTBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
model = VisionTextDualEncoderModel.from_vision_text_pretrained(
|
||||
"hf-internal-testing/tiny-random-vit", "hf-internal-testing/tiny-bert"
|
||||
)
|
||||
batch_size = 13
|
||||
pixel_values = floats_tensor(
|
||||
[
|
||||
batch_size,
|
||||
model.vision_model.config.num_channels,
|
||||
model.vision_model.config.image_size,
|
||||
model.vision_model.config.image_size,
|
||||
]
|
||||
)
|
||||
input_ids = ids_tensor([batch_size, 4], model.text_model.config.vocab_size)
|
||||
attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs = {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask}
|
||||
|
||||
return model, inputs
|
||||
|
||||
def get_vision_text_model(self, vision_config, text_config):
|
||||
vision_model = ViTModel(vision_config).eval()
|
||||
text_model = BertModel(text_config).eval()
|
||||
return vision_model, text_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
vit_model_tester = ViTModelTester(self)
|
||||
bert_model_tester = BertModelTester(self)
|
||||
vision_config_and_inputs = vit_model_tester.prepare_config_and_inputs()
|
||||
text_config_and_inputs = bert_model_tester.prepare_config_and_inputs()
|
||||
|
||||
vision_config, pixel_values, _ = vision_config_and_inputs
|
||||
|
||||
(
|
||||
text_config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = text_config_and_inputs
|
||||
|
||||
return {
|
||||
"text_config": text_config,
|
||||
"vision_config": vision_config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": input_mask,
|
||||
"input_ids": input_ids,
|
||||
"text_token_type_ids": token_type_ids,
|
||||
"text_sequence_labels": sequence_labels,
|
||||
"text_token_labels": token_labels,
|
||||
"text_choice_labels": choice_labels,
|
||||
}
|
||||
|
||||
|
||||
@require_torch
|
||||
class DeiTRobertaModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
model = VisionTextDualEncoderModel.from_vision_text_pretrained(
|
||||
"hf-internal-testing/tiny-random-deit", "hf-internal-testing/tiny-random-roberta"
|
||||
)
|
||||
batch_size = 13
|
||||
pixel_values = floats_tensor(
|
||||
[
|
||||
batch_size,
|
||||
model.vision_model.config.num_channels,
|
||||
model.vision_model.config.image_size,
|
||||
model.vision_model.config.image_size,
|
||||
]
|
||||
)
|
||||
input_ids = ids_tensor([batch_size, 4], model.text_model.config.vocab_size)
|
||||
attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs = {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask}
|
||||
|
||||
return model, inputs
|
||||
|
||||
def check_vision_text_output_attention(
|
||||
self, text_config, input_ids, attention_mask, vision_config, pixel_values=None, **kwargs
|
||||
):
|
||||
vision_model, text_model = self.get_vision_text_model(vision_config, text_config)
|
||||
model = VisionTextDualEncoderModel(vision_model=vision_model, text_model=text_model)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output = model(
|
||||
input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, output_attentions=True
|
||||
)
|
||||
|
||||
vision_attentions = output.vision_model_output.attentions
|
||||
self.assertEqual(len(vision_attentions), vision_config.num_hidden_layers)
|
||||
|
||||
# in DeiT, the seq_len equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens)
|
||||
image_size = to_2tuple(vision_model.config.image_size)
|
||||
patch_size = to_2tuple(vision_model.config.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
seq_len = num_patches + 2
|
||||
self.assertEqual(vision_attentions[0].shape[-3:], (vision_config.num_attention_heads, seq_len, seq_len))
|
||||
|
||||
text_attentions = output.text_model_output.attentions
|
||||
self.assertEqual(len(text_attentions), text_config.num_hidden_layers)
|
||||
|
||||
self.assertEqual(
|
||||
text_attentions[0].shape[-3:],
|
||||
(text_config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]),
|
||||
)
|
||||
|
||||
def get_vision_text_model(self, vision_config, text_config):
|
||||
vision_model = DeiTModel(vision_config).eval()
|
||||
text_model = RobertaModel(text_config).eval()
|
||||
return vision_model, text_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
deit_model_tester = DeiTModelTester(self)
roberta_model_tester = RobertaModelTester(self)
vision_config_and_inputs = deit_model_tester.prepare_config_and_inputs()
text_config_and_inputs = roberta_model_tester.prepare_config_and_inputs()
|
||||
|
||||
vision_config, pixel_values, _ = vision_config_and_inputs
|
||||
|
||||
(
|
||||
text_config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = text_config_and_inputs
|
||||
|
||||
return {
|
||||
"text_config": text_config,
|
||||
"vision_config": vision_config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": input_mask,
|
||||
"input_ids": input_ids,
|
||||
"text_token_type_ids": token_type_ids,
|
||||
"text_sequence_labels": sequence_labels,
|
||||
"text_token_labels": token_labels,
|
||||
"text_choice_labels": choice_labels,
|
||||
}
|
||||
|
||||
# skip as DeiT is not available in Flax
|
||||
def test_pt_flax_equivalence(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class CLIPVisionBertModelTest(VisionTextDualEncoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model_and_inputs(self):
|
||||
model = VisionTextDualEncoderModel.from_vision_text_pretrained(
|
||||
"hf-internal-testing/tiny-random-clip", "hf-internal-testing/tiny-bert"
|
||||
)
|
||||
batch_size = 13
|
||||
pixel_values = floats_tensor(
|
||||
[
|
||||
batch_size,
|
||||
model.vision_model.config.num_channels,
|
||||
model.vision_model.config.image_size,
|
||||
model.vision_model.config.image_size,
|
||||
]
|
||||
)
|
||||
input_ids = ids_tensor([batch_size, 4], model.text_model.config.vocab_size)
|
||||
attention_mask = random_attention_mask([batch_size, 4])
|
||||
inputs = {"pixel_values": pixel_values, "input_ids": input_ids, "attention_mask": attention_mask}
|
||||
|
||||
return model, inputs
|
||||
|
||||
def get_vision_text_model(self, vision_config, text_config):
|
||||
vision_model = CLIPVisionModel(vision_config).eval()
|
||||
text_model = BertModel(text_config).eval()
|
||||
return vision_model, text_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
clip_model_tester = CLIPVisionModelTester(self)
|
||||
bert_model_tester = BertModelTester(self)
|
||||
vision_config_and_inputs = clip_model_tester.prepare_config_and_inputs()
|
||||
text_config_and_inputs = bert_model_tester.prepare_config_and_inputs()
|
||||
|
||||
vision_config, pixel_values = vision_config_and_inputs
|
||||
|
||||
(
|
||||
text_config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = text_config_and_inputs
|
||||
|
||||
return {
|
||||
"text_config": text_config,
|
||||
"vision_config": vision_config,
|
||||
"pixel_values": pixel_values,
|
||||
"attention_mask": input_mask,
|
||||
"input_ids": input_ids,
|
||||
"text_token_type_ids": token_type_ids,
|
||||
"text_sequence_labels": sequence_labels,
|
||||
"text_token_labels": token_labels,
|
||||
"text_choice_labels": choice_labels,
|
||||
}
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
class VisionTextDualEncoderIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference(self):
|
||||
model = VisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian", logit_scale_init_value=1)
|
||||
processor = VisionTextDualEncoderProcessor.from_pretrained("clip-italian/clip-italian")
|
||||
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
inputs = processor(
|
||||
text=["una foto di un gatto", "una foto di un cane"], images=image, padding=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify the logits
|
||||
self.assertEqual(outputs.logits_per_image.shape, (inputs.pixel_values.shape[0], inputs.input_ids.shape[0]))
|
||||
self.assertEqual(
|
||||
outputs.logits_per_text.shape,
|
||||
(inputs.input_ids.shape[0], inputs.pixel_values.shape[0]),
|
||||
)
|
||||
|
||||
expected_logits = torch.tensor([[1.2284727, 0.3104122]])
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
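The attention-shape checks in this file depend on the vision sequence length: the number of patches plus one extra token for ViT ([CLS]) or two for DeiT ([CLS] and distillation). The sketch below restates that arithmetic in isolation. The helper name `expected_seq_len` and the 224/16 image and patch sizes are assumptions chosen for illustration; `to_2tuple` mirrors the helper defined in the test file.

```python
import collections.abc


def to_2tuple(x):
    # Pass iterables through unchanged, duplicate scalars into a pair.
    if isinstance(x, collections.abc.Iterable):
        return x
    return (x, x)


def expected_seq_len(image_size, patch_size, num_extra_tokens):
    """Number of patches along both axes, plus the model's extra tokens."""
    image_size = to_2tuple(image_size)
    patch_size = to_2tuple(patch_size)
    num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
    return num_patches + num_extra_tokens


# Illustrative sizes only; the test configs use much smaller images.
vit_seq_len = expected_seq_len(224, 16, num_extra_tokens=1)   # patches + [CLS]
deit_seq_len = expected_seq_len(224, 16, num_extra_tokens=2)  # patches + [CLS] + distillation
```

Each attention tensor is then expected to end in `(num_attention_heads, seq_len, seq_len)`, which is why the DeiT test overrides the mixin's check rather than reusing it.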
|
||||
170
tests/test_processor_vision_text_dual_encoder.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import BertTokenizerFast
|
||||
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_vision_available
|
||||
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES, BertTokenizer
|
||||
from transformers.testing_utils import require_tokenizers, require_vision
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import VisionTextDualEncoderProcessor, ViTFeatureExtractor
|
||||
|
||||
|
||||
@require_tokenizers
|
||||
@require_vision
|
||||
class VisionTextDualEncoderProcessorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
|
||||
# fmt: off
|
||||
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "want", "##want", "##ed", "wa", "un", "runn", "##ing", ",", "low", "lowest"]
|
||||
# fmt: on
|
||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
feature_extractor_map = {
|
||||
"do_resize": True,
|
||||
"size": 18,
|
||||
"do_normalize": True,
|
||||
"image_mean": [0.5, 0.5, 0.5],
|
||||
"image_std": [0.5, 0.5, 0.5],
|
||||
}
|
||||
self.feature_extractor_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
|
||||
with open(self.feature_extractor_file, "w", encoding="utf-8") as fp:
|
||||
json.dump(feature_extractor_map, fp)
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def get_feature_extractor(self, **kwargs):
|
||||
return ViTFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def prepare_image_inputs(self):
|
||||
"""Prepare a list containing one random RGB PIL image (height 30, width 400)."""
|
||||
|
||||
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
|
||||
|
||||
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
|
||||
|
||||
return image_inputs
|
||||
|
||||
def test_save_load_pretrained_default(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
|
||||
processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
processor = VisionTextDualEncoderProcessor.from_pretrained(self.tmpdirname)
|
||||
|
||||
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
|
||||
self.assertIsInstance(processor.tokenizer, (BertTokenizer, BertTokenizerFast))
|
||||
|
||||
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
|
||||
self.assertIsInstance(processor.feature_extractor, ViTFeatureExtractor)
|
||||
|
||||
def test_save_load_pretrained_additional_features(self):
|
||||
processor = VisionTextDualEncoderProcessor(
|
||||
tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor()
|
||||
)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
|
||||
feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)
|
||||
|
||||
processor = VisionTextDualEncoderProcessor.from_pretrained(
|
||||
self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
|
||||
)
|
||||
|
||||
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
|
||||
self.assertIsInstance(processor.tokenizer, (BertTokenizer, BertTokenizerFast))
|
||||
|
||||
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
|
||||
self.assertIsInstance(processor.feature_extractor, ViTFeatureExtractor)
|
||||
|
||||
def test_feature_extractor(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
input_feat_extract = feature_extractor(image_input, return_tensors="np")
|
||||
input_processor = processor(images=image_input, return_tensors="np")
|
||||
|
||||
for key in input_feat_extract.keys():
|
||||
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
||||
|
||||
def test_tokenizer(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
input_str = "lower newer"
|
||||
|
||||
encoded_processor = processor(text=input_str)
|
||||
|
||||
encoded_tok = tokenizer(input_str)
|
||||
|
||||
for key in encoded_tok.keys():
|
||||
self.assertListEqual(encoded_tok[key], encoded_processor[key])
|
||||
|
||||
def test_processor(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
inputs = processor(text=input_str, images=image_input)
|
||||
|
||||
self.assertListEqual(list(inputs.keys()), ["input_ids", "token_type_ids", "attention_mask", "pixel_values"])
|
||||
|
||||
# test if it raises when no input is passed
|
||||
with self.assertRaises(ValueError):
|
||||
processor()
|
||||
|
||||
def test_tokenizer_decode(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
|
||||
|
||||
decoded_processor = processor.batch_decode(predicted_ids)
|
||||
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
||||
|
||||
self.assertListEqual(decoded_tok, decoded_processor)
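The processor tests above pin down two behaviours: text and image encodings are merged into one mapping with the keys `input_ids`, `token_type_ids`, `attention_mask`, `pixel_values`, and calling the processor with no input raises `ValueError`. A toy sketch of that dispatch pattern follows. `ToyDualProcessor`, `toy_tokenizer` and `toy_feature_extractor` are hypothetical stand-ins written for this example and are not the real `VisionTextDualEncoderProcessor` code.

```python
class ToyDualProcessor:
    """Hypothetical stand-in that merges tokenizer and feature-extractor outputs."""

    def __init__(self, tokenizer, feature_extractor):
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor

    def __call__(self, text=None, images=None):
        # At least one modality is required.
        if text is None and images is None:
            raise ValueError("You have to specify either text or images.")
        encoding = {}
        if text is not None:
            encoding.update(self.tokenizer(text))
        if images is not None:
            encoding.update(self.feature_extractor(images))
        return encoding


def toy_tokenizer(text):
    # Fake token ids: one id per whitespace-separated word.
    ids = list(range(len(text.split())))
    return {"input_ids": ids, "token_type_ids": [0] * len(ids), "attention_mask": [1] * len(ids)}


def toy_feature_extractor(images):
    # Fake pixel values: pass the inputs through as a list.
    return {"pixel_values": list(images)}


processor = ToyDualProcessor(toy_tokenizer, toy_feature_extractor)
inputs = processor(text="lower newer", images=[[0.0, 1.0]])
```

Because the text encoding is added first, the key order matches what `test_processor` asserts.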
|
||||
@@ -94,6 +94,8 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
|
||||
"test_modeling_tf_xlm_roberta.py",
|
||||
"test_modeling_xlm_prophetnet.py",
|
||||
"test_modeling_xlm_roberta.py",
|
||||
"test_modeling_vision_text_dual_encoder.py",
|
||||
"test_modeling_flax_vision_text_dual_encoder.py",
|
||||
]
|
||||
|
||||
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
|
||||
|
||||