From 4cb5ffa93d400636b6809563dc806b64a9b9550d Mon Sep 17 00:00:00 2001 From: Anahita Bhiwandiwalla Date: Tue, 28 Feb 2023 09:21:48 -0800 Subject: [PATCH] Add loss for BridgeTowerForMaskedLM and BridgeTowerForImageAndTextRetrieval (#21684) * Add loss for BridgeTowerForMaskedLM and BridgeTowerForImageAndTextRetrieval * minor fix return_dict * implement test for loss computation --------- Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com> Co-authored-by: Tiep Le --- .../bridgetower/modeling_bridgetower.py | 30 +++++++-- .../bridgetower/test_modeling_bridgetower.py | 66 +++++++++++++++++++ 2 files changed, 89 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 16733d6f66..1fbc85ad31 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -22,6 +22,7 @@ from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn +from torch.nn import CrossEntropyLoss from ...activations import ACT2FN, QuickGELUActivation from ...modeling_outputs import ( @@ -1535,8 +1536,10 @@ class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel): labels: Optional[torch.LongTensor] = None, ) -> Union[MaskedLMOutput, Tuple[torch.FloatTensor]]: r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels are currently not supported. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` Returns: Examples: @@ -1580,11 +1583,17 @@ class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel): ) mlm_logits = self.mlm_score(outputs.text_features if return_dict else outputs[0]) + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.text_config.vocab_size), labels.view(-1)) if not return_dict: - return tuple(mlm_logits) + output = tuple(mlm_logits) + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output return MaskedLMOutput( + loss=masked_lm_loss, logits=mlm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, @@ -1627,8 +1636,9 @@ class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel): labels: Optional[torch.LongTensor] = None, ) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]: r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels are currently not supported. + labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*): + Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match. + The pairs with 0 will be skipped for calculation. Returns: Examples: @@ -1673,11 +1683,17 @@ class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel): logits = self.itm_score(pooler_output) + itm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + itm_loss = loss_fct(logits, labels) + if not return_dict: - return tuple(logits) + output = tuple(logits) + return ((itm_loss,) + output) if itm_loss is not None else output return SequenceClassifierOutput( - loss=None, + loss=itm_loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, diff --git a/tests/models/bridgetower/test_modeling_bridgetower.py b/tests/models/bridgetower/test_modeling_bridgetower.py index 7d35959968..7405293a1c 100644 --- a/tests/models/bridgetower/test_modeling_bridgetower.py +++ b/tests/models/bridgetower/test_modeling_bridgetower.py @@ -392,6 +392,13 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase): self.assertEqual(outputs.logits.shape, expected_shape) self.assertTrue(outputs.logits[0, 1].item() > outputs.logits[0, 0].item()) + # verify loss + inputs["labels"] = torch.ones(1, dtype=torch.long, device=torch_device) + inputs = inputs.to(torch_device) + with torch.no_grad(): + outputs = model(**inputs) + self.assertAlmostEqual(outputs.loss.item(), 0.5108, places=4) + @slow def test_masked_language_modeling(self): model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-base-itm-mlm").to(torch_device) @@ -412,3 +419,62 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase): # verify predicted word predicted_id = outputs.logits.argmax(dim=-1).squeeze(0).tolist()[4] self.assertTrue(processor.decode([predicted_id]) == " cats") + + # verify loss + inputs["labels"] = inputs["input_ids"].clone() + inputs = inputs.to(torch_device) + with torch.no_grad(): + outputs = model(**inputs) + self.assertAlmostEqual(outputs.loss.item(), 5.7373, places=4) + + +@require_torch +@unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+") +class BridgeTowerModelTrainingTest(unittest.TestCase): + all_training_supported_model_classes = ( + (BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else () + ) + + def setUp(self): + self.model_tester = BridgeTowerModelTester(self) + self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265) + + def _prepare_inputs_for_training(self, model_class): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if model_class == BridgeTowerForMaskedLM: + inputs_dict["labels"] = inputs_dict["input_ids"] + elif model_class == BridgeTowerForImageAndTextRetrieval: + inputs_dict["labels"] = ids_tensor([1], 2) + return config, inputs_dict + + def _get_non_used_layer_names(self, model_class): + non_used_layer_names = ["text_model.pooler"] + if model_class == BridgeTowerForMaskedLM: + non_used_layer_names = non_used_layer_names + [ + "cross_modal_image_layers.5", + "cross_modal_image_pooler", + "cross_modal_text_pooler", + ] + return non_used_layer_names + + def _is_layer_used(self, model_class, layer_name): + non_used_layer_names = self._get_non_used_layer_names(model_class) + for non_used_layer_name in non_used_layer_names: + if non_used_layer_name in layer_name: + return False + return True + + def test_training(self): + for model_class in self.all_training_supported_model_classes: + config, inputs_dict = self._prepare_inputs_for_training(model_class) + model = model_class(config) + model.to(torch_device) + model.train() + + loss = model(**inputs_dict).loss + loss.backward() + + # verify the gradients of used layers' weight are not None + for name, param in model.named_parameters(): + if self._is_layer_used(model_class, name): + self.assertIsNotNone(param.grad, f"Gradients should not be None - got {param.grad} for {name}")