Merge pull request #495 from SudoSharma/patch-2
Fix gradient overflow issue during attention mask
This commit is contained in:
@@ -218,7 +218,7 @@ class Attention(nn.Module):
|
|||||||
w = w / math.sqrt(v.size(-1))
|
w = w / math.sqrt(v.size(-1))
|
||||||
nd, ns = w.size(-2), w.size(-1)
|
nd, ns = w.size(-2), w.size(-1)
|
||||||
b = self.bias[:, :, ns-nd:ns, :ns]
|
b = self.bias[:, :, ns-nd:ns, :ns]
|
||||||
w = w * b - 1e10 * (1 - b)
|
w = w * b - 1e4 * (1 - b)
|
||||||
|
|
||||||
w = nn.Softmax(dim=-1)(w)
|
w = nn.Softmax(dim=-1)(w)
|
||||||
return torch.matmul(w, v)
|
return torch.matmul(w, v)
|
||||||
|
|||||||
Reference in New Issue
Block a user