Update BridgeTowerForContrastiveLearning (#22145)

* Use return_loss for BridgeTowerForContrastiveLearning, add example

* fix tests

* Update example in BridgeTowerForContrastiveLearning

* Update test_modeling_bridgetower.py

* update model output format

* minor update

* Update src/transformers/models/bridgetower/modeling_bridgetower.py

* make style

---------

Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com>
Co-authored-by: Tiep Le <tiep.le@intel.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Anahita Bhiwandiwalla
2023-03-15 12:54:38 -07:00
committed by GitHub
parent 42ad693b7b
commit 16121bae5c
2 changed files with 45 additions and 29 deletions

View File

@@ -94,7 +94,7 @@ class BridgeTowerModelTester:
self.num_hidden_layers = num_hidden_layers
self.tie_word_embeddings = tie_word_embeddings
self.init_layernorm_from_vision_encoder = init_layernorm_from_vision_encoder
- self.vocab_size = 50265
+ self.vocab_size = 99
self.num_channels = 3
self.seq_length = 4
self.num_image_features = 325
@@ -188,7 +188,7 @@ class BridgeTowerModelTester:
result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, pixel_mask=pixel_mask)
result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
- self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, 50265))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
@@ -231,7 +231,7 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
def setUp(self):
self.model_tester = BridgeTowerModelTester(self)
- self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
+ self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99)
def test_config(self):
self.config_tester.run_common_tests()
@@ -486,9 +486,9 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
image = prepare_img()
text = "a bunch of cats laying on a tower."
- inputs = processor(image, text, return_tensors="pt").to(torch_device)
+ inputs = processor(image, text, padding=True, return_tensors="pt").to(torch_device)
with torch.no_grad():
- outputs = model(**inputs, output_hidden_states=True)
+ outputs = model(**inputs, output_hidden_states=True, return_loss=True)
# verify the logits
expected_shape = torch.Size([1, 3, 512])
@@ -507,14 +507,16 @@ class BridgeTowerModelTrainingTest(unittest.TestCase):
def setUp(self):
self.model_tester = BridgeTowerModelTester(self)
- self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
+ self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99)
def _prepare_inputs_for_training(self, model_class):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if model_class == BridgeTowerForMaskedLM:
inputs_dict["labels"] = inputs_dict["input_ids"]
- elif model_class == BridgeTowerForImageAndTextRetrieval or model_class == BridgeTowerForContrastiveLearning:
+ elif model_class == BridgeTowerForImageAndTextRetrieval:
inputs_dict["labels"] = ids_tensor([1], 2)
+ elif model_class == BridgeTowerForContrastiveLearning:
+ inputs_dict["return_loss"] = True
return config, inputs_dict
def _get_non_used_layer_names(self, model_class):