Flava multimodal add attention mask (#29446)

* flava multimodal add attn mask

* make style

* check mask is not None
This commit is contained in:
Raushan Turganbay
2024-03-07 16:45:47 +05:00
committed by GitHub
parent 9288e759ad
commit 923733c22b
2 changed files with 19 additions and 9 deletions

View File

@@ -1287,9 +1287,9 @@ class FlavaModelIntegrationTest(unittest.TestCase):
outputs = model(**inputs, return_dict=True)
# verify the embeddings
self.assertAlmostEqual(outputs.image_embeddings.sum().item(), -1352.53540, places=4)
self.assertAlmostEqual(outputs.image_embeddings.sum().item(), -1352.54943, places=4)
self.assertAlmostEqual(outputs.text_embeddings.sum().item(), -198.98225, places=4)
self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -3988.51367, places=4)
self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -4030.466552, places=4)
@require_vision
@@ -1339,9 +1339,9 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0290069, places=4)
self.assertAlmostEqual(outputs.loss.item(), 11.0626, places=4)
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 2.0736470, places=4)
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.025580, places=4)
self.assertAlmostEqual(outputs.loss.item(), 11.37761, places=4)
@slow
def test_inference_with_itm_labels(self):
@@ -1390,6 +1390,6 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 6.89590501, places=4)
self.assertAlmostEqual(outputs.loss.item(), 9.1995, places=4)
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 2.0736470, places=4)
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 6.8962264, places=4)
self.assertAlmostEqual(outputs.loss.item(), 9.6090, places=4)