Fix missing head_mask transpose in XLNetRelativeAttention
This commit is contained in:
@@ -284,7 +284,7 @@ class XLNetRelativeAttention(nn.Module):
|
|||||||
|
|
||||||
# Mask heads if we want to
|
# Mask heads if we want to
|
||||||
if head_mask is not None:
|
if head_mask is not None:
|
||||||
attn_prob = attn_prob * head_mask
|
attn_prob = attn_prob * torch.einsum('ijbn->bnij', head_mask)
|
||||||
|
|
||||||
# attention output
|
# attention output
|
||||||
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, v_head_h)
|
attn_vec = torch.einsum('bnij,jbnd->ibnd', attn_prob, v_head_h)
|
||||||
|
|||||||
Reference in New Issue
Block a user