cleanup deltas
This commit is contained in:
@@ -640,9 +640,8 @@ class SelfAttention(nn.Module):
|
||||
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
|
||||
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
attn_weights_float = F.softmax(attn_weights, dim=-1)
|
||||
attn_probs = F.dropout(attn_weights_float, p=self.dropout, training=self.training,)
|
||||
attn_weights = attn_weights_float.type_as(attn_weights)
|
||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
|
||||
|
||||
assert v is not None
|
||||
attn_output = torch.bmm(attn_probs, v)
|
||||
|
||||
Reference in New Issue
Block a user