From a4e1a1d02894b8d801f5d0182e1979b55daaeaa4 Mon Sep 17 00:00:00 2001
From: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Date: Mon, 10 Jun 2024 15:01:27 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=20FLAVA:=20Remove=20double=20softm?=
 =?UTF-8?q?ax=20(#31322)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Remove double softmax
---
 src/transformers/models/flava/modeling_flava.py | 2 --
 tests/models/flava/test_modeling_flava.py       | 2 +-
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py
index 5acbad05c3..dbc4e51703 100644
--- a/src/transformers/models/flava/modeling_flava.py
+++ b/src/transformers/models/flava/modeling_flava.py
@@ -472,8 +472,6 @@ class FlavaSelfAttention(nn.Module):
 
         # Normalize the attention scores to probabilities.
         attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-        # Normalize the attention scores to probabilities.
-        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
 
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
diff --git a/tests/models/flava/test_modeling_flava.py b/tests/models/flava/test_modeling_flava.py
index 7e067115e7..388e2f041f 100644
--- a/tests/models/flava/test_modeling_flava.py
+++ b/tests/models/flava/test_modeling_flava.py
@@ -1285,7 +1285,7 @@ class FlavaModelIntegrationTest(unittest.TestCase):
         # verify the embeddings
         self.assertAlmostEqual(outputs.image_embeddings.sum().item(), -1352.53540, places=4)
         self.assertAlmostEqual(outputs.text_embeddings.sum().item(), -198.98225, places=4)
-        self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -4030.4602050, places=4)
+        self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -4030.4604492, places=4)
 
 
 @require_vision