Add loss for BridgeTowerForMaskedLM and BridgeTowerForImageAndTextRetrieval (#21684)
* Add loss for BridgeTowerForMaskedLM and BridgeTowerForImageAndTextRetrieval * minor fix return_dict * implement test for loss computation --------- Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com> Co-authored-by: Tiep Le <tiep.le@intel.com>
This commit is contained in:
committed by
GitHub
parent
7f4f8b97d0
commit
4cb5ffa93d
@@ -22,6 +22,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN, QuickGELUActivation
|
||||
from ...modeling_outputs import (
|
||||
@@ -1535,8 +1536,10 @@ class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
) -> Union[MaskedLMOutput, Tuple[torch.FloatTensor]]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels are currently not supported.
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
||||
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
||||
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
@@ -1580,11 +1583,17 @@ class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
|
||||
)
|
||||
|
||||
mlm_logits = self.mlm_score(outputs.text_features if return_dict else outputs[0])
|
||||
masked_lm_loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
||||
masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.text_config.vocab_size), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
return tuple(mlm_logits)
|
||||
output = tuple(mlm_logits)
|
||||
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
||||
|
||||
return MaskedLMOutput(
|
||||
loss=masked_lm_loss,
|
||||
logits=mlm_logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
@@ -1627,8 +1636,9 @@ class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
) -> Union[SequenceClassifierOutput, Tuple[torch.FloatTensor]]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels are currently not supported.
|
||||
labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
|
||||
Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
|
||||
The pairs with 0 will be skipped for calculation.
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
@@ -1673,11 +1683,17 @@ class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
|
||||
|
||||
logits = self.itm_score(pooler_output)
|
||||
|
||||
itm_loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
itm_loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(logits)
|
||||
output = tuple(logits)
|
||||
return ((itm_loss,) + output) if itm_loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=None,
|
||||
loss=itm_loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
|
||||
@@ -392,6 +392,13 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
self.assertTrue(outputs.logits[0, 1].item() > outputs.logits[0, 0].item())
|
||||
|
||||
# verify loss
|
||||
inputs["labels"] = torch.ones(1, dtype=torch.long, device=torch_device)
|
||||
inputs = inputs.to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
self.assertAlmostEqual(outputs.loss.item(), 0.5108, places=4)
|
||||
|
||||
@slow
|
||||
def test_masked_language_modeling(self):
|
||||
model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-base-itm-mlm").to(torch_device)
|
||||
@@ -412,3 +419,62 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
|
||||
# verify predicted word
|
||||
predicted_id = outputs.logits.argmax(dim=-1).squeeze(0).tolist()[4]
|
||||
self.assertTrue(processor.decode([predicted_id]) == " cats")
|
||||
|
||||
# verify loss
|
||||
inputs["labels"] = inputs["input_ids"].clone()
|
||||
inputs = inputs.to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
self.assertAlmostEqual(outputs.loss.item(), 5.7373, places=4)
|
||||
|
||||
|
||||
@require_torch
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
|
||||
class BridgeTowerModelTrainingTest(unittest.TestCase):
|
||||
all_training_supported_model_classes = (
|
||||
(BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else ()
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BridgeTowerModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
|
||||
|
||||
def _prepare_inputs_for_training(self, model_class):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if model_class == BridgeTowerForMaskedLM:
|
||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||
elif model_class == BridgeTowerForImageAndTextRetrieval:
|
||||
inputs_dict["labels"] = ids_tensor([1], 2)
|
||||
return config, inputs_dict
|
||||
|
||||
def _get_non_used_layer_names(self, model_class):
|
||||
non_used_layer_names = ["text_model.pooler"]
|
||||
if model_class == BridgeTowerForMaskedLM:
|
||||
non_used_layer_names = non_used_layer_names + [
|
||||
"cross_modal_image_layers.5",
|
||||
"cross_modal_image_pooler",
|
||||
"cross_modal_text_pooler",
|
||||
]
|
||||
return non_used_layer_names
|
||||
|
||||
def _is_layer_used(self, model_class, layer_name):
|
||||
non_used_layer_names = self._get_non_used_layer_names(model_class)
|
||||
for non_used_layer_name in non_used_layer_names:
|
||||
if non_used_layer_name in layer_name:
|
||||
return False
|
||||
return True
|
||||
|
||||
def test_training(self):
|
||||
for model_class in self.all_training_supported_model_classes:
|
||||
config, inputs_dict = self._prepare_inputs_for_training(model_class)
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
loss = model(**inputs_dict).loss
|
||||
loss.backward()
|
||||
|
||||
# verify the gradients of used layers' weight are not None
|
||||
for name, param in model.named_parameters():
|
||||
if self._is_layer_used(model_class, name):
|
||||
self.assertIsNotNone(param.grad, f"Gradients should not be None - got {param.grad} for {name}")
|
||||
|
||||
Reference in New Issue
Block a user