Fix missed head transpose

This commit is contained in:
Simon Layton
2019-10-03 09:23:16 -04:00
parent d51b589404
commit 9ffda216ec

View File

@@ -284,7 +284,7 @@ class XLNetRelativeAttention(nn.Module):
# Mask heads if we want to
if head_mask is not None:
-attn_prob = attn_prob * head_mask
+attn_prob = attn_prob * torch.einsum('ijbn->bnij', head_mask)
# attention output
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, v_head_h)