🚨 FLAVA: Remove double softmax (#31322)

Remove double softmax
This commit is contained in:
amyeroberts
2024-06-10 15:01:27 +01:00
committed by GitHub
parent 8fff07ded0
commit a4e1a1d028
2 changed files with 1 addition and 3 deletions

View File

@@ -472,8 +472,6 @@ class FlavaSelfAttention(nn.Module):
 # Normalize the attention scores to probabilities.
 attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-# Normalize the attention scores to probabilities.
-attention_probs = nn.functional.softmax(attention_scores, dim=-1)
 # This is actually dropping out entire tokens to attend to, which might
 # seem a bit unusual, but is taken from the original Transformer paper.

View File

@@ -1285,7 +1285,7 @@ class FlavaModelIntegrationTest(unittest.TestCase):
 # verify the embeddings
 self.assertAlmostEqual(outputs.image_embeddings.sum().item(), -1352.53540, places=4)
 self.assertAlmostEqual(outputs.text_embeddings.sum().item(), -198.98225, places=4)
-self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -4030.4602050, places=4)
+self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -4030.4604492, places=4)
 @require_vision