Merge pull request #1154 from ziliwang/master

fix: hard coding for max number
This commit is contained in:
Thomas Wolf
2019-08-30 23:23:08 +02:00
committed by GitHub

View File

@@ -418,7 +418,10 @@ class XLNetRelativeAttention(nn.Module):
attn_score = (ac + bd + ef) * self.scale
if attn_mask is not None:
# attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
attn_score = attn_score - 1e30 * attn_mask
if attn_mask.dtype == torch.float16:
attn_score = attn_score - 65500 * attn_mask
else:
attn_score = attn_score - 1e30 * attn_mask
# attention probability
attn_prob = F.softmax(attn_score, dim=1)