Update BridgeTowerForContrastiveLearning (#22145)
* Use return_loss for BridgeTowerForContrastiveLearning, add example
* fix tests
* Update example in BridgeTowerForContrastiveLearning
* Update test_modeling_bridgetower.py
* update model output format
* minor update
* Update src/transformers/models/bridgetower/modeling_bridgetower.py
* make style
---------
Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com>
Co-authored-by: Tiep Le <tiep.le@intel.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
42ad693b7b
commit
16121bae5c
@@ -94,7 +94,7 @@ class BridgeTowerModelTester:
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.init_layernorm_from_vision_encoder = init_layernorm_from_vision_encoder
|
||||
self.vocab_size = 50265
|
||||
self.vocab_size = 99
|
||||
self.num_channels = 3
|
||||
self.seq_length = 4
|
||||
self.num_image_features = 325
|
||||
@@ -188,7 +188,7 @@ class BridgeTowerModelTester:
|
||||
result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, pixel_mask=pixel_mask)
|
||||
result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
|
||||
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, 50265))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
@@ -231,7 +231,7 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BridgeTowerModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
|
||||
self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
@@ -486,9 +486,9 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
|
||||
processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
|
||||
image = prepare_img()
|
||||
text = "a bunch of cats laying on a tower."
|
||||
inputs = processor(image, text, return_tensors="pt").to(torch_device)
|
||||
inputs = processor(image, text, padding=True, return_tensors="pt").to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs, output_hidden_states=True)
|
||||
outputs = model(**inputs, output_hidden_states=True, return_loss=True)
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size([1, 3, 512])
|
||||
@@ -507,14 +507,16 @@ class BridgeTowerModelTrainingTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BridgeTowerModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
|
||||
self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99)
|
||||
|
||||
def _prepare_inputs_for_training(self, model_class):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if model_class == BridgeTowerForMaskedLM:
|
||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||
elif model_class == BridgeTowerForImageAndTextRetrieval or model_class == BridgeTowerForContrastiveLearning:
|
||||
elif model_class == BridgeTowerForImageAndTextRetrieval:
|
||||
inputs_dict["labels"] = ids_tensor([1], 2)
|
||||
elif model_class == BridgeTowerForContrastiveLearning:
|
||||
inputs_dict["return_loss"] = True
|
||||
return config, inputs_dict
|
||||
|
||||
def _get_non_used_layer_names(self, model_class):
|
||||
|
||||
Reference in New Issue
Block a user