[WIP] Add BridgeTowerForContrastiveLearning (#21964)
* Add BridgeTower for ITC * Fix review feedback * Rename BridgeTowerForITC, cleanup * Fix style and quality * implement tests --------- Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com> Co-authored-by: Tiep Le <tiep.le@intel.com>
This commit is contained in:
committed by
GitHub
parent
edea08a6b0
commit
de81adf978
@@ -24,14 +24,25 @@ from transformers.testing_utils import require_torch, require_vision, slow, torc
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
_config_zero_init,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
random_attention_mask,
|
||||
)
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM, BridgeTowerModel
|
||||
from transformers import (
|
||||
BridgeTowerForContrastiveLearning,
|
||||
BridgeTowerForImageAndTextRetrieval,
|
||||
BridgeTowerForMaskedLM,
|
||||
BridgeTowerModel,
|
||||
)
|
||||
from transformers.models.bridgetower.modeling_bridgetower import BRIDGETOWER_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_10
|
||||
else:
|
||||
@@ -65,6 +76,8 @@ class BridgeTowerModelTester:
|
||||
text_config=None,
|
||||
vision_config=None,
|
||||
image_size=288,
|
||||
contrastive_hidden_size=512,
|
||||
logit_scale_init_value=2.6592,
|
||||
):
|
||||
self.parent = parent
|
||||
self.share_cross_modal_transformer_layers = share_cross_modal_transformer_layers
|
||||
@@ -90,6 +103,8 @@ class BridgeTowerModelTester:
|
||||
self.is_training = False
|
||||
self.expected_num_hidden_layers = 32
|
||||
self.output_hidden_states = output_hidden_states
|
||||
self.contrastive_hidden_size = contrastive_hidden_size
|
||||
self.logit_scale_init_value = logit_scale_init_value
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -118,6 +133,8 @@ class BridgeTowerModelTester:
|
||||
init_layernorm_from_vision_encoder=self.init_layernorm_from_vision_encoder,
|
||||
num_channels=self.num_channels,
|
||||
output_hidden_states=self.output_hidden_states,
|
||||
contrastive_hidden_size=self.contrastive_hidden_size,
|
||||
logit_scale_init_value=self.logit_scale_init_value,
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
@@ -189,7 +206,14 @@ class BridgeTowerModelTester:
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
|
||||
class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(BridgeTowerModel, BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else ()
|
||||
(
|
||||
BridgeTowerModel,
|
||||
BridgeTowerForImageAndTextRetrieval,
|
||||
BridgeTowerForMaskedLM,
|
||||
BridgeTowerForContrastiveLearning,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = {"feature-extraction": BridgeTowerModel} if is_torch_available() else {}
|
||||
|
||||
@@ -347,6 +371,29 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
if self.has_attentions:
|
||||
self.assertIsNotNone(attentions.grad)
|
||||
|
||||
# override as the `logit_scale` parameter initilization is different for BRIDGE TOWER
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
if name == "logit_scale":
|
||||
self.assertAlmostEqual(
|
||||
param.data.item(),
|
||||
config.logit_scale_init_value,
|
||||
delta=1e-3,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
else:
|
||||
self.assertIn(
|
||||
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||
[0.0, 1.0],
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
@unittest.skip(reason="""Bridge Tower does not have input/output embeddings. So this test is not applicable.""")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
@@ -429,12 +476,31 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
|
||||
outputs = model(**inputs)
|
||||
self.assertAlmostEqual(outputs.loss.item(), 5.7373, places=4)
|
||||
|
||||
@slow
|
||||
def test_constrastive_learning(self):
|
||||
model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc").to(
|
||||
torch_device
|
||||
)
|
||||
model.eval()
|
||||
processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
|
||||
image = prepare_img()
|
||||
text = "a bunch of cats laying on a tower."
|
||||
inputs = processor(image, text, return_tensors="pt").to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, output_hidden_states=True)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size([1, 3, 512])
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
|
||||
class BridgeTowerModelTrainingTest(unittest.TestCase):
|
||||
all_training_supported_model_classes = (
|
||||
(BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else ()
|
||||
(BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM, BridgeTowerForContrastiveLearning)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
@@ -445,7 +511,7 @@ class BridgeTowerModelTrainingTest(unittest.TestCase):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if model_class == BridgeTowerForMaskedLM:
|
||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||
elif model_class == BridgeTowerForImageAndTextRetrieval:
|
||||
elif model_class == BridgeTowerForImageAndTextRetrieval or model_class == BridgeTowerForContrastiveLearning:
|
||||
inputs_dict["labels"] = ids_tensor([1], 2)
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
Reference in New Issue
Block a user