Add loss for BridgeTowerForMaskedLM and BridgeTowerForImageAndTextRetrieval (#21684)
* Add loss for BridgeTowerForMaskedLM and BridgeTowerForImageAndTextRetrieval * minor fix return_dict * implement test for loss computation --------- Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com> Co-authored-by: Tiep Le <tiep.le@intel.com>
This commit is contained in:
committed by
GitHub
parent
7f4f8b97d0
commit
4cb5ffa93d
@@ -392,6 +392,13 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
self.assertTrue(outputs.logits[0, 1].item() > outputs.logits[0, 0].item())
|
||||
|
||||
# verify loss
|
||||
inputs["labels"] = torch.ones(1, dtype=torch.long, device=torch_device)
|
||||
inputs = inputs.to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
self.assertAlmostEqual(outputs.loss.item(), 0.5108, places=4)
|
||||
|
||||
@slow
|
||||
def test_masked_language_modeling(self):
|
||||
model = BridgeTowerForMaskedLM.from_pretrained("BridgeTower/bridgetower-base-itm-mlm").to(torch_device)
|
||||
@@ -412,3 +419,62 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
|
||||
# verify predicted word
|
||||
predicted_id = outputs.logits.argmax(dim=-1).squeeze(0).tolist()[4]
|
||||
self.assertTrue(processor.decode([predicted_id]) == " cats")
|
||||
|
||||
# verify loss
|
||||
inputs["labels"] = inputs["input_ids"].clone()
|
||||
inputs = inputs.to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
self.assertAlmostEqual(outputs.loss.item(), 5.7373, places=4)
|
||||
|
||||
|
||||
@require_torch
|
||||
@unittest.skipIf(not is_torch_greater_or_equal_than_1_10, "BridgeTower is only available in torch v1.10+")
|
||||
class BridgeTowerModelTrainingTest(unittest.TestCase):
|
||||
all_training_supported_model_classes = (
|
||||
(BridgeTowerForImageAndTextRetrieval, BridgeTowerForMaskedLM) if is_torch_available() else ()
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BridgeTowerModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
|
||||
|
||||
def _prepare_inputs_for_training(self, model_class):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if model_class == BridgeTowerForMaskedLM:
|
||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||
elif model_class == BridgeTowerForImageAndTextRetrieval:
|
||||
inputs_dict["labels"] = ids_tensor([1], 2)
|
||||
return config, inputs_dict
|
||||
|
||||
def _get_non_used_layer_names(self, model_class):
|
||||
non_used_layer_names = ["text_model.pooler"]
|
||||
if model_class == BridgeTowerForMaskedLM:
|
||||
non_used_layer_names = non_used_layer_names + [
|
||||
"cross_modal_image_layers.5",
|
||||
"cross_modal_image_pooler",
|
||||
"cross_modal_text_pooler",
|
||||
]
|
||||
return non_used_layer_names
|
||||
|
||||
def _is_layer_used(self, model_class, layer_name):
|
||||
non_used_layer_names = self._get_non_used_layer_names(model_class)
|
||||
for non_used_layer_name in non_used_layer_names:
|
||||
if non_used_layer_name in layer_name:
|
||||
return False
|
||||
return True
|
||||
|
||||
def test_training(self):
|
||||
for model_class in self.all_training_supported_model_classes:
|
||||
config, inputs_dict = self._prepare_inputs_for_training(model_class)
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
loss = model(**inputs_dict).loss
|
||||
loss.backward()
|
||||
|
||||
# verify the gradients of used layers' weight are not None
|
||||
for name, param in model.named_parameters():
|
||||
if self._is_layer_used(model_class, name):
|
||||
self.assertIsNotNone(param.grad, f"Gradients should not be None - got {param.grad} for {name}")
|
||||
|
||||
Reference in New Issue
Block a user