Merge pull request #495 from SudoSharma/patch-2

Fix gradient overflow issue during attention mask
This commit is contained in:
Thomas Wolf
2019-04-17 11:10:36 +02:00
committed by GitHub

View File

@@ -218,7 +218,7 @@ class Attention(nn.Module):
w = w / math.sqrt(v.size(-1))
nd, ns = w.size(-2), w.size(-1)
b = self.bias[:, :, ns-nd:ns, :ns]
w = w * b - 1e10 * (1 - b)
w = w * b - 1e4 * (1 - b)
w = nn.Softmax(dim=-1)(w)
return torch.matmul(w, v)