Add loss for BridgeTowerForMaskedLM and BridgeTowerForImageAndTextRetrieval (#21684)
* Add loss for BridgeTowerForMaskedLM and BridgeTowerForImageAndTextRetrieval
* Minor fix for return_dict
* Implement test for loss computation

---------

Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com>
Co-authored-by: Tiep Le <tiep.le@intel.com>
This commit is contained in:
committed by
GitHub
parent
7f4f8b97d0
commit
4cb5ffa93d
@@ -22,6 +22,7 @@ from typing import List, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from ...activations import ACT2FN, QuickGELUActivation
|
from ...activations import ACT2FN, QuickGELUActivation
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -1535,8 +1536,10 @@ class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
|
|||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[MaskedLMOutput, Tuple[torch.FloatTensor]]:
|
) -> Union[MaskedLMOutput, Tuple[torch.FloatTensor]]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Labels are currently not supported.
|
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
||||||
|
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
||||||
|
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
@@ -1580,11 +1583,17 @@ class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
mlm_logits = self.mlm_score(outputs.text_features if return_dict else outputs[0])
|
mlm_logits = self.mlm_score(outputs.text_features if return_dict else outputs[0])
|
||||||
|
masked_lm_loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
||||||
|
masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.text_config.vocab_size), labels.view(-1))
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(mlm_logits)
|
output = tuple(mlm_logits)
|
||||||
|
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
||||||
|
|
||||||
return MaskedLMOutput(
|
return MaskedLMOutput(
|
||||||
|
loss=masked_lm_loss,
|
||||||
logits=mlm_logits,
|
logits=mlm_logits,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
@@ -1627,8 +1636,9 @@ class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
|
|||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
|
) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
|
||||||
Labels are currently not supported.
|
Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
|
||||||
|
The pairs with 0 will be skipped for calculation.
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
@@ -1673,11 +1683,17 @@ class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
|
|||||||
|
|
||||||
logits = self.itm_score(pooler_output)
|
logits = self.itm_score(pooler_output)
|
||||||
|
|
||||||
|
itm_loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
itm_loss = loss_fct(logits, labels)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(logits)
|
output = tuple(logits)
|
||||||
|
return ((itm_loss,) + output) if itm_loss is not None else output
|
||||||
|
|
||||||
return SequenceClassifierOutput(
|
return SequenceClassifierOutput(
|
||||||
loss=None,
|
loss=itm_loss,
|
||||||
logits=logits,
|
logits=logits,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
|
|||||||
@@ -392,6 +392,13 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||||
self.assertTrue(outputs.logits[0, 1].item() > outputs.logits[0, 0].item())
|
self.assertTrue(outputs.logits[0, 1].item() > outputs.logits[0, 0].item())
|
||||||
|
|
||||||
|
# verify loss
|
||||||
|
inputs["labels"] = torch.ones(1, dtype=torch.long, device=torch_device)
|
||||||
|
inputs = inputs.to(torch_device)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
self.assertAlmostEqual(outputs.loss.item(), 0.5108, places=4)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_masked_language_modeling(self):
|
def test_masked_language_modeling(self):
|
||||||
model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-base-itm-mlm").to(torch_device)
|
model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-base-itm-mlm").to(torch_device)
|
||||||
@@ -412,3 +419,62 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
|
|||||||
# verify predicted word
|
# verify predicted word
|
||||||
predicted_id = outputs.logits.argmax(dim=-1).squeeze(0).tolist()[4]
|
predicted_id = outputs.logits.argmax(dim=-1).squeeze(0).tolist()[4]
|
||||||
self.assertTrue(processor.decode([predicted_id]) == " cats")
|
self.assertTrue(processor.decode([predicted_id]) == " cats")
|
||||||
|
|
||||||
|
# verify loss
|
||||||
|
inputs["labels"] = inputs["input_ids"].clone()
|
||||||
|
inputs = inputs.to(torch_device)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
self.assertAlmostEqual(outputs.loss.item(), 5.7373, places=4)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
@unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
class BridgeTowerModelTrainingTest(unittest.TestCase):
    """Check that the BridgeTower heads which return a loss can be back-propagated.

    For every class in `all_training_supported_model_classes`, a model is built from the tester
    config, run once in training mode with labels, and every parameter that is not listed as
    unused for that class is checked for a gradient after `loss.backward()`.
    """

    # Empty when torch is unavailable so the module can still be imported.
    all_training_supported_model_classes = (
        (BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else ()
    )

    def setUp(self):
        """Create the model tester and config tester used by this test case."""
        self.model_tester = BridgeTowerModelTester(self)
        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)

    def _prepare_inputs_for_training(self, model_class):
        """Return `(config, inputs_dict)` with a `labels` entry suited to `model_class`.

        * `BridgeTowerForMaskedLM`: labels are a copy of `input_ids`.
        * `BridgeTowerForImageAndTextRetrieval`: one random binary label per example.

        Any other class gets the common inputs without labels.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        if model_class == BridgeTowerForMaskedLM:
            # Clone so that `labels` and `input_ids` never share storage.
            inputs_dict["labels"] = inputs_dict["input_ids"].clone()
        elif model_class == BridgeTowerForImageAndTextRetrieval:
            # Derive the batch size from the inputs instead of hard-coding 1, so the labels
            # keep matching the logits if the tester's batch size changes.
            batch_size = inputs_dict["input_ids"].shape[0]
            inputs_dict["labels"] = ids_tensor([batch_size], 2)
        return config, inputs_dict

    def _get_non_used_layer_names(self, model_class):
        """Return the name fragments of layers that are exempt from the gradient check.

        NOTE(review): presumably these layers do not contribute to the loss of the given head;
        the index in "cross_modal_image_layers.5" looks tied to the tester config's layer
        count - confirm against the model if the config changes.
        """
        non_used_layer_names = ["text_model.pooler"]
        if model_class == BridgeTowerForMaskedLM:
            non_used_layer_names = non_used_layer_names + [
                "cross_modal_image_layers.5",
                "cross_modal_image_pooler",
                "cross_modal_text_pooler",
            ]
        return non_used_layer_names

    def _is_layer_used(self, model_class, layer_name):
        """Return `False` if `layer_name` contains any unused-layer fragment, `True` otherwise."""
        non_used_layer_names = self._get_non_used_layer_names(model_class)
        return not any(fragment in layer_name for fragment in non_used_layer_names)

    def test_training(self):
        """Run one forward/backward pass per supported class and check gradients exist."""
        for model_class in self.all_training_supported_model_classes:
            # subTest reports which model class failed and lets the remaining classes still run.
            with self.subTest(model_class=model_class.__name__):
                config, inputs_dict = self._prepare_inputs_for_training(model_class)
                model = model_class(config)
                model.to(torch_device)
                model.train()

                loss = model(**inputs_dict).loss
                loss.backward()

                # verify the gradients of used layers' weight are not None
                for name, param in model.named_parameters():
                    if self._is_layer_used(model_class, name):
                        self.assertIsNotNone(
                            param.grad, f"Gradients should not be None - got {param.grad} for {name}"
                        )
|
||||||
|
|||||||
Reference in New Issue
Block a user