Add BridgeTower model (#20775)
* Commit with BTModel and latest HF code * Placeholder classes for BTForMLM and BTForITR * Importing Bert classes from transformers * Removed objectives.py and dist_utils.py * Removed swin_transformer.py * Add image normalization, BridgeTowerForImageAndTextRetrieval * Add center_crop * Removing bert tokenizer and LCI references * Tested config loading from HF transformers hub * Removed state_dict updates and added path to hub * Enable center crop * Getting image_size from config, renaming num_heads and num_layers * Handling max_length in BridgeTowerProcessor * Add BridgeTowerForMaskedLM * Add doc string for BridgeTowerConfig * Add doc strings for BT config, processor, image processor * Adding docs, removed swin * Removed convert_bridgetower_original_to_pytorch.py * Added doc files for bridgetower, removed is_vision * Add support attention_mask=None and BridgeTowerModelOutput * Fix formatting * Fixes with 'make style', 'make quality', 'make fixup' * Remove downstream tasks from BridgeTowerModel * Formatting fixes, add return_dict to BT models * Clean up after doc_test * Update BTModelOutput return type, fix todo in doc * Remove loss_names from init * implement tests and update tuples returned by models * Add image reference to bridgetower.mdx * after make fix-copies, make fixup, make style, make quality, make repo-consistency * Rename class names with BridgeTower prefix * Fix for image_size in BTImageProcessor * implement feature extraction bridgetower tests * Update image_mean and image_std to be list * remove unused import * Removed old comments * Rework CLIP * update config in tests followed config update * Formatting fixes * Add copied from for BridgeTowerPredictionHeadTransform * Update bridgetower.mdx * Update test_feature_extraction_bridgetower.py * Update bridgetower.mdx * BridgeTowerForMaskedLM is conditioned on image too * Add BridgeTowerForMaskedLM * Fixes * Call post_init to init weights * Move freeze layers into method * Remove BTFeatureExtractor, add BT under multimodal models * Remove BTFeatureExtractor, add BT under multimodal models * Code review feedback - cleanup * Rename variables * Formatting and style to PR review feedback * Move center crop after resize * Use named parameters * Style fix for modeling_bridgetower.py * Update docs/source/en/model_doc/bridgetower.mdx Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update docs/source/en/model_doc/bridgetower.mdx Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update docs/source/en/model_doc/bridgetower.mdx Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/bridgetower/modeling_bridgetower.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/bridgetower/modeling_bridgetower.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update docs/source/en/model_doc/bridgetower.mdx Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/models/bridgetower/modeling_bridgetower.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Rename config params, copy BERT classes, clean comments * Cleanup irtr * Replace Roberta imports, add BTTextConfig and Model * Update docs, add visionconfig, consistent arg names * make fixup * Comments for forward in BTModel and make fixup * correct tests * Remove inconsistent roberta copied from * Add BridgeTowerTextModel to dummy_pt_objects.py * Add BridgeTowerTextModel to IGNORE_NON_TESTED * Update docs for BT Text and Vision Configs * Treat BridgeTowerTextModel as a private model * BridgeTowerTextModel as private * Run make fix-copies * Adding BTTextModel to PRIVATE_MODELS * Fix for issue with BT Text and Image configs * make style changes * Update README_ja.md Add から to BridgeTower's description * Clean up config, .mdx and arg names * Fix init_weights. Remove nn.Sequential * Formatting and style fixes * Re-add tie_word_embeddings in config * update test implementation * update style * remove commented out * fix style * Update README with abs for BridgeTower * fix style * fix mdx file * Update bridgetower.mdx * Update img src in bridgetower.mdx * Update README.md * Update README.md * resolve style failed * Update _toctree.yml * Update README_ja.md * Removed mlp_ratio, rename feats, rename BTCLIPModel * Replace BTCLIP with BTVisionModel,pass in vision_config to BTVisionModel * Add test_initialization support * Add support for output_hidden_states * Update support for output_hidden_states * Add support for output_attentions * Add docstring for output_hidden_states * update tests * add bridgetowervisionmodel as private model * rerun the PR test * Remove model_type, pass configs to classes, renames * Change self.device to use weight device * Remove image_size * Style check fixes * Add hidden_size and num_hidden_layers to BridgeTowerTransformer * Update device setting * cosmetic update * trigger test again * trigger tests again * Update test_modeling_bridgetower.py trigger tests again * Update test_modeling_bridgetower.py * minor update * re-trigger tests * Update docs/source/en/model_doc/bridgetower.mdx Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Remove pad, update max_text_len, doc cleanup, pass eps to LayerNorm * Added copied to, some more review feedback * make fixup * Use BridgeTowerVisionEmbeddings * Code cleanup * Fixes for BridgeTowerVisionEmbeddings * style checks * re-tests * fix embedding * address comment on init file * retrigger tests * update import prepare_image_inputs * update test_image_processing_bridgetower.py to reflect test_image_processing_common.py * retrigger tests Co-authored-by: Shaoyen Tseng <shao-yen.tseng@intel.com> Co-authored-by: Tiep Le <tiep.le@intel.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
39799fbf85
commit
3a6e4a221c
0
tests/models/bridgetower/__init__.py
Normal file
0
tests/models/bridgetower/__init__.py
Normal file
258
tests/models/bridgetower/test_image_processing_bridgetower.py
Normal file
258
tests/models/bridgetower/test_image_processing_bridgetower.py
Normal file
@@ -0,0 +1,258 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingSavingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import BridgeTowerImageProcessor
|
||||
|
||||
|
||||
class BridgeTowerImageProcessingTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
do_resize: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
size_divisor: int = 32,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
do_center_crop: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = [0.48145466, 0.4578275, 0.40821073],
|
||||
image_std: Optional[Union[float, List[float]]] = [0.26862954, 0.26130258, 0.27577711],
|
||||
do_pad: bool = True,
|
||||
batch_size=7,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
num_channels=3,
|
||||
):
|
||||
self.parent = parent
|
||||
self.do_resize = do_resize
|
||||
self.size = size if size is not None else {"shortest_edge": 288}
|
||||
self.size_divisor = size_divisor
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.do_center_crop = do_center_crop
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_pad = do_pad
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_normalize": self.do_normalize,
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"size_divisor": self.size_divisor,
|
||||
}
|
||||
|
||||
def get_expected_values(self, image_inputs, batched=False):
|
||||
"""
|
||||
This function computes the expected height and width when providing images to BridgeTowerImageProcessor,
|
||||
assuming do_resize is set to True with a scalar size and size_divisor.
|
||||
"""
|
||||
if not batched:
|
||||
size = self.size["shortest_edge"]
|
||||
image = image_inputs[0]
|
||||
if isinstance(image, Image.Image):
|
||||
w, h = image.size
|
||||
else:
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
scale = size / min(w, h)
|
||||
if h < w:
|
||||
newh, neww = size, scale * w
|
||||
else:
|
||||
newh, neww = scale * h, size
|
||||
|
||||
max_size = int((1333 / 800) * size)
|
||||
if max(newh, neww) > max_size:
|
||||
scale = max_size / max(newh, neww)
|
||||
newh = newh * scale
|
||||
neww = neww * scale
|
||||
|
||||
newh, neww = int(newh + 0.5), int(neww + 0.5)
|
||||
expected_height, expected_width = (
|
||||
newh // self.size_divisor * self.size_divisor,
|
||||
neww // self.size_divisor * self.size_divisor,
|
||||
)
|
||||
|
||||
else:
|
||||
expected_values = []
|
||||
for image in image_inputs:
|
||||
expected_height, expected_width = self.get_expected_values([image])
|
||||
expected_values.append((expected_height, expected_width))
|
||||
expected_height = max(expected_values, key=lambda item: item[0])[0]
|
||||
expected_width = max(expected_values, key=lambda item: item[1])[1]
|
||||
|
||||
return expected_height, expected_width
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class BridgeTowerImageProcessingTest(ImageProcessingSavingTestMixin, unittest.TestCase):
|
||||
image_processing_class = BridgeTowerImageProcessor if is_vision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
self.image_processor_tester = BridgeTowerImageProcessingTester(self)
|
||||
|
||||
@property
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "size_divisor"))
|
||||
|
||||
def test_batch_feature(self):
|
||||
pass
|
||||
|
||||
def test_call_pil(self):
|
||||
# Initialize feature_extractor
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PIL images
|
||||
image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
|
||||
expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs)
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
|
||||
expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs, batched=True)
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.image_processor_tester.batch_size,
|
||||
self.image_processor_tester.num_channels,
|
||||
expected_height,
|
||||
expected_width,
|
||||
),
|
||||
)
|
||||
|
||||
def test_call_numpy(self):
|
||||
# Initialize feature_extractor
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False, numpify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
|
||||
expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs)
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
|
||||
expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs, batched=True)
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.image_processor_tester.batch_size,
|
||||
self.image_processor_tester.num_channels,
|
||||
expected_height,
|
||||
expected_width,
|
||||
),
|
||||
)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize feature_extractor
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False, torchify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
|
||||
expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs)
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
|
||||
)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
|
||||
expected_height, expected_width = self.image_processor_tester.get_expected_values(image_inputs, batched=True)
|
||||
self.assertEqual(
|
||||
encoded_images.shape,
|
||||
(
|
||||
self.image_processor_tester.batch_size,
|
||||
self.image_processor_tester.num_channels,
|
||||
expected_height,
|
||||
expected_width,
|
||||
),
|
||||
)
|
||||
|
||||
def test_equivalence_pad_and_create_pixel_mask(self):
|
||||
# Initialize feature_extractors
|
||||
image_processing_1 = self.image_processing_class(**self.image_processor_dict)
|
||||
image_processing_2 = self.image_processing_class(do_resize=False, do_normalize=False, do_rescale=False)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False, torchify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test whether the method "pad_and_return_pixel_mask" and calling the image processor return the same tensors
|
||||
encoded_images_with_method = image_processing_1.pad_and_create_pixel_mask(image_inputs, return_tensors="pt")
|
||||
encoded_images = image_processing_2(image_inputs, return_tensors="pt")
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(encoded_images_with_method["pixel_values"], encoded_images["pixel_values"], atol=1e-4)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4)
|
||||
)
|
||||
409
tests/models/bridgetower/test_modeling_bridgetower.py
Normal file
409
tests/models/bridgetower/test_modeling_bridgetower.py
Normal file
@@ -0,0 +1,409 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The Intel Labs Team Authors, The Microsoft Research Team Authors and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch BridgeTower model. """
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import BridgeTowerConfig, is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM, BridgeTowerModel
|
||||
from transformers.models.bridgetower.modeling_bridgetower import BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_10
|
||||
else:
|
||||
is_torch_greater_or_equal_than_1_10 = False
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import BridgeTowerProcessor
|
||||
|
||||
|
||||
class BridgeTowerModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
share_cross_modal_transformer_layers=True,
|
||||
drop_rate=0.1,
|
||||
head_hidden_scale=2,
|
||||
hidden_act="gelu",
|
||||
hidden_size=768,
|
||||
initializer_factor=1,
|
||||
is_encoder_decoder=False,
|
||||
layer_norm_eps=1e-05,
|
||||
share_link_tower_layers=False,
|
||||
link_tower_type="add",
|
||||
num_attention_heads=12,
|
||||
num_hidden_layers=6,
|
||||
tie_word_embeddings=False,
|
||||
init_layernorm_from_vision_encoder=False,
|
||||
output_hidden_states=False,
|
||||
text_config=None,
|
||||
vision_config=None,
|
||||
image_size=288,
|
||||
):
|
||||
self.parent = parent
|
||||
self.share_cross_modal_transformer_layers = share_cross_modal_transformer_layers
|
||||
self.drop_rate = drop_rate
|
||||
self.head_hidden_scale = head_hidden_scale
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_size = hidden_size
|
||||
self.initializer_factor = initializer_factor
|
||||
self.is_encoder_decoder = is_encoder_decoder
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.share_link_tower_layers = share_link_tower_layers
|
||||
self.link_tower_type = link_tower_type
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.init_layernorm_from_vision_encoder = init_layernorm_from_vision_encoder
|
||||
self.vocab_size = 50265
|
||||
self.num_channels = 3
|
||||
self.seq_length = 4
|
||||
self.num_image_features = 325
|
||||
self.batch_size = 1
|
||||
self.image_size = image_size
|
||||
self.is_training = False
|
||||
self.expected_num_hidden_layers = 32
|
||||
self.output_hidden_states = output_hidden_states
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
pixel_mask = random_attention_mask([self.batch_size, self.image_size, self.image_size])
|
||||
config = self.get_config()
|
||||
return (config, input_ids, attention_mask, pixel_values, pixel_mask)
|
||||
|
||||
def get_config(self):
|
||||
return BridgeTowerConfig(
|
||||
share_cross_modal_transformer_layers=self.share_cross_modal_transformer_layers,
|
||||
drop_rate=self.drop_rate,
|
||||
head_hidden_scale=self.head_hidden_scale,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_size=self.hidden_size,
|
||||
initializer_factor=self.initializer_factor,
|
||||
image_size=self.image_size,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
layer_norm_eps=self.layer_norm_eps,
|
||||
share_link_tower_layers=self.share_link_tower_layers,
|
||||
link_tower_type=self.link_tower_type,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
tie_word_embeddings=self.tie_word_embeddings,
|
||||
init_layernorm_from_vision_encoder=self.init_layernorm_from_vision_encoder,
|
||||
num_channels=self.num_channels,
|
||||
output_hidden_states=self.output_hidden_states,
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
pixel_values,
|
||||
pixel_mask,
|
||||
):
|
||||
model = BridgeTowerModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, pixel_mask=pixel_mask)
|
||||
result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
|
||||
self.parent.assertEqual(result["text_features"].shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(
|
||||
result["image_features"].shape, (self.batch_size, self.num_image_features, self.hidden_size)
|
||||
)
|
||||
self.parent.assertEqual(result["pooler_output"].shape, (self.batch_size, 2 * self.hidden_size))
|
||||
|
||||
def create_and_check_for_image_and_text_retrieval(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
pixel_values,
|
||||
pixel_mask,
|
||||
):
|
||||
bridgetower_itm_output_last_dimension = 2
|
||||
|
||||
model = BridgeTowerForImageAndTextRetrieval(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, pixel_mask=pixel_mask)
|
||||
result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
|
||||
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, bridgetower_itm_output_last_dimension))
|
||||
|
||||
def create_and_check_for_masked_language_modeling(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
pixel_values,
|
||||
pixel_mask,
|
||||
):
|
||||
model = BridgeTowerForMaskedLM(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, pixel_mask=pixel_mask)
|
||||
result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
|
||||
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, attention_mask, pixel_values, pixel_mask) = config_and_inputs
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
"pixel_mask": pixel_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
|
||||
class BridgeTowerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(BridgeTowerModel, BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else ()
|
||||
)
|
||||
|
||||
is_training = False
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_resize_embeddings = False
|
||||
has_attentions = False
|
||||
|
||||
# function to extract meaningful tensor from output per different model_class
|
||||
def extract_output(self, outputs, model_class):
|
||||
return outputs["pooler_output"] if model_class == "BridgeTowerModel" else outputs["logits"]
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BridgeTowerModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_for_image_and_text_retrieval(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_image_and_text_retrieval(*config_and_inputs)
|
||||
|
||||
def test_for_masked_language_modeling(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_masked_language_modeling(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = BridgeTowerModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
# Override as extracting meaningful tensor from output is different for BridgeTower
|
||||
def test_save_load(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**input_dict)
|
||||
|
||||
out_2 = self.extract_output(outputs, model_class.__name__)
|
||||
out_2 = out_2.cpu().numpy()
|
||||
out_2[np.isnan(out_2)] = 0
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
model.to(torch_device)
|
||||
with torch.no_grad():
|
||||
after_outputs = model(**input_dict)
|
||||
|
||||
# Make sure we don't have nans
|
||||
out_1 = self.extract_output(after_outputs, model_class.__name__)
|
||||
out_1 = out_1.cpu().numpy()
|
||||
out_1[np.isnan(out_1)] = 0
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
# Override this as `hidden states output` is different for BridgeTower
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states_text, hidden_states_vision, hidden_states_cross = (
|
||||
outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||
)
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
self.assertEqual(
|
||||
sum((len(hidden_states_text), len(hidden_states_vision), len(hidden_states_cross))),
|
||||
expected_num_layers,
|
||||
)
|
||||
|
||||
seq_length = self.model_tester.seq_length
|
||||
num_image_features = self.model_tester.num_image_features
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states_text[0].shape[-2:]),
|
||||
[seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
self.assertListEqual(
|
||||
list(hidden_states_vision[0].shape),
|
||||
[num_image_features, 1, self.model_tester.hidden_size],
|
||||
)
|
||||
self.assertListEqual(
|
||||
list(hidden_states_cross[0][0].shape[-2:]),
|
||||
[seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
self.assertListEqual(
|
||||
list(hidden_states_cross[0][1].shape[-2:]),
|
||||
[num_image_features, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# Override as `hidden states output` is different for BridgeTower
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
# no need to test all models as different heads yield the same functionality
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
output = outputs[0]
|
||||
|
||||
# Encoder-/Decoder-only models
|
||||
hidden_states = outputs.hidden_states[0][0]
|
||||
hidden_states.retain_grad()
|
||||
|
||||
if self.has_attentions:
|
||||
attentions = outputs.attentions[0][0]
|
||||
attentions.retain_grad()
|
||||
|
||||
output.flatten()[0].backward(retain_graph=True)
|
||||
|
||||
self.assertIsNotNone(hidden_states.grad)
|
||||
|
||||
if self.has_attentions:
|
||||
self.assertIsNotNone(attentions.grad)
|
||||
|
||||
@unittest.skip(reason="""Bridge Tower does not have input/output embeddings. So this test is not applicable.""")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="""Bridge Tower does not have input/output embeddings. Thus this test is not applicable.""")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||
return image
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
|
||||
class BridgeTowerModelIntegrationTest(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_processor(self):
|
||||
return (
|
||||
BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base-itm-mlm")
|
||||
if is_vision_available()
|
||||
else None
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_image_and_text_retrieval(self):
|
||||
model = BridgeTowerForImageAndTextRetrieval.from_pretrained("BridgeTower/bridgetower-base-itm-mlm").to(
|
||||
torch_device
|
||||
)
|
||||
model.eval()
|
||||
processor = self.default_processor
|
||||
image = prepare_img()
|
||||
text = "a bunch of cats laying on a tower."
|
||||
inputs = processor(image, text, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size([1, 2])
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
self.assertTrue(outputs.logits[0, 1].item() > outputs.logits[0, 0].item())
|
||||
|
||||
@slow
|
||||
def test_masked_language_modeling(self):
|
||||
model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-base-itm-mlm").to(torch_device)
|
||||
model.eval()
|
||||
processor = self.default_processor
|
||||
image = prepare_img()
|
||||
text = "a bunch of <mask> laying on a tower."
|
||||
inputs = processor(image, text, return_tensors="pt").to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size([1, 11, 50265])
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
# verify predicted word
|
||||
predicted_id = outputs.logits.argmax(dim=-1).squeeze(0).tolist()[4]
|
||||
self.assertTrue(processor.decode([predicted_id]) == " cats")
|
||||
Reference in New Issue
Block a user