From 9ffda216ec65126f3fd834da072445ec146f1d77 Mon Sep 17 00:00:00 2001 From: Simon Layton Date: Thu, 3 Oct 2019 09:23:16 -0400 Subject: [PATCH] Fix missed head transpose --- transformers/modeling_xlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformers/modeling_xlnet.py b/transformers/modeling_xlnet.py index 15a0ed0add..e72f316b38 100644 --- a/transformers/modeling_xlnet.py +++ b/transformers/modeling_xlnet.py @@ -284,7 +284,7 @@ class XLNetRelativeAttention(nn.Module): # Mask heads if we want to if head_mask is not None: - attn_prob = attn_prob * head_mask + attn_prob = attn_prob * torch.einsum('ijbn->bnij', head_mask) # attention output attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, v_head_h)