From 923733c22bf4d3cc6661c8cd3b730b275e9a938e Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Thu, 7 Mar 2024 16:45:47 +0500 Subject: [PATCH] Flava multimodal add attention mask (#29446) * flava multimodal add attn mask * make style * check mask is not None --- src/transformers/models/flava/modeling_flava.py | 12 +++++++++++- tests/models/flava/test_modeling_flava.py | 16 ++++++++-------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index f96e4292a1..0e5cfe1b68 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -1415,8 +1415,18 @@ class FlavaModel(FlavaPreTrainedModel): multimodal_embeddings = None multimodal_output = None if image_mm_projection is not None and text_mm_projection is not None and not skip_multimodal_encoder: + if attention_mask is not None: + batch_size, seq_len, _ = image_mm_projection.shape + if self.multimodal_model.use_cls_token: + seq_len += 1 + attention_mask_image = torch.ones(batch_size, seq_len, device=image_mm_projection.device) + attention_multimodal = torch.cat([attention_mask_image, attention_mask], dim=1) + else: + attention_multimodal = None multimodal_input = torch.cat([image_mm_projection, text_mm_projection], dim=1) - multimodal_output = self.multimodal_model(multimodal_input, return_dict=return_dict) + multimodal_output = self.multimodal_model( + multimodal_input, attention_mask=attention_multimodal, return_dict=return_dict + ) multimodal_embeddings = multimodal_output[0] if not return_dict: diff --git a/tests/models/flava/test_modeling_flava.py b/tests/models/flava/test_modeling_flava.py index 48a070d9fe..b17a6f7b54 100644 --- a/tests/models/flava/test_modeling_flava.py +++ b/tests/models/flava/test_modeling_flava.py @@ -1287,9 +1287,9 @@ class FlavaModelIntegrationTest(unittest.TestCase): outputs = model(**inputs, return_dict=True) # verify the embeddings - self.assertAlmostEqual(outputs.image_embeddings.sum().item(), -1352.53540, places=4) + self.assertAlmostEqual(outputs.image_embeddings.sum().item(), -1352.54943, places=4) self.assertAlmostEqual(outputs.text_embeddings.sum().item(), -198.98225, places=4) - self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -3988.51367, places=4) + self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -4030.466552, places=4) @require_vision @@ -1339,9 +1339,9 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase): expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device) self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3)) - self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4) - self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0290069, places=4) - self.assertAlmostEqual(outputs.loss.item(), 11.0626, places=4) + self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 2.0736470, places=4) + self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.025580, places=4) + self.assertAlmostEqual(outputs.loss.item(), 11.37761, places=4) @slow def test_inference_with_itm_labels(self): @@ -1390,6 +1390,6 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase): expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device) self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3)) - self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4) - self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 6.89590501, places=4) - self.assertAlmostEqual(outputs.loss.item(), 9.1995, places=4) + self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 2.0736470, places=4) + self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 6.8962264, places=4) + self.assertAlmostEqual(outputs.loss.item(), 9.6090, places=4)