Added patch to remaining models
This commit is contained in:
@@ -271,6 +271,7 @@ class MultiHeadAttention(nn.Module):
|
||||
self.k_lin = nn.Linear(dim, dim)
|
||||
self.v_lin = nn.Linear(dim, dim)
|
||||
self.out_lin = nn.Linear(dim, dim)
|
||||
self.pruned_heads = []
|
||||
|
||||
def prune_heads(self, heads):
|
||||
attention_head_size = self.dim // self.n_heads
|
||||
@@ -278,6 +279,7 @@ class MultiHeadAttention(nn.Module):
|
||||
return
|
||||
mask = torch.ones(self.n_heads, attention_head_size)
|
||||
for head in heads:
|
||||
head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
|
||||
mask[head] = 0
|
||||
mask = mask.view(-1).contiguous().eq(1)
|
||||
index = torch.arange(len(mask))[mask].long()
|
||||
@@ -289,6 +291,7 @@ class MultiHeadAttention(nn.Module):
|
||||
# Update hyper params
|
||||
self.n_heads = self.n_heads - len(heads)
|
||||
self.dim = attention_head_size * self.n_heads
|
||||
self.pruned_heads.extend(heads)
|
||||
|
||||
def forward(self, input, mask, kv=None, cache=None, head_mask=None):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user