Added patch to remaining models
This commit is contained in:
@@ -337,12 +337,14 @@ class BertAttention(nn.Module):
|
|||||||
super(BertAttention, self).__init__()
|
super(BertAttention, self).__init__()
|
||||||
self.self = BertSelfAttention(config)
|
self.self = BertSelfAttention(config)
|
||||||
self.output = BertSelfOutput(config)
|
self.output = BertSelfOutput(config)
|
||||||
|
self.pruned_heads = []
|
||||||
|
|
||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
if len(heads) == 0:
|
if len(heads) == 0:
|
||||||
return
|
return
|
||||||
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
|
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
|
||||||
for head in heads:
|
for head in heads:
|
||||||
|
head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
|
||||||
mask[head] = 0
|
mask[head] = 0
|
||||||
mask = mask.view(-1).contiguous().eq(1)
|
mask = mask.view(-1).contiguous().eq(1)
|
||||||
index = torch.arange(len(mask))[mask].long()
|
index = torch.arange(len(mask))[mask].long()
|
||||||
@@ -354,6 +356,7 @@ class BertAttention(nn.Module):
|
|||||||
# Update hyper params
|
# Update hyper params
|
||||||
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
||||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||||
|
self.pruned_heads.extend(heads)
|
||||||
|
|
||||||
def forward(self, input_tensor, attention_mask, head_mask=None):
|
def forward(self, input_tensor, attention_mask, head_mask=None):
|
||||||
self_outputs = self.self(input_tensor, attention_mask, head_mask)
|
self_outputs = self.self(input_tensor, attention_mask, head_mask)
|
||||||
|
|||||||
@@ -249,12 +249,14 @@ class Attention(nn.Module):
|
|||||||
self.c_proj = Conv1D(n_state, nx)
|
self.c_proj = Conv1D(n_state, nx)
|
||||||
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
||||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||||
|
self.pruned_heads = []
|
||||||
|
|
||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
if len(heads) == 0:
|
if len(heads) == 0:
|
||||||
return
|
return
|
||||||
mask = torch.ones(self.n_head, self.split_size // self.n_head)
|
mask = torch.ones(self.n_head, self.split_size // self.n_head)
|
||||||
for head in heads:
|
for head in heads:
|
||||||
|
head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
|
||||||
mask[head] = 0
|
mask[head] = 0
|
||||||
mask = mask.view(-1).contiguous().eq(1)
|
mask = mask.view(-1).contiguous().eq(1)
|
||||||
index = torch.arange(len(mask))[mask].long()
|
index = torch.arange(len(mask))[mask].long()
|
||||||
@@ -265,6 +267,7 @@ class Attention(nn.Module):
|
|||||||
# Update hyper params
|
# Update hyper params
|
||||||
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
|
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
|
||||||
self.n_head = self.n_head - len(heads)
|
self.n_head = self.n_head - len(heads)
|
||||||
|
self.pruned_heads.extend(heads)
|
||||||
|
|
||||||
def _attn(self, q, k, v, head_mask=None):
|
def _attn(self, q, k, v, head_mask=None):
|
||||||
w = torch.matmul(q, k)
|
w = torch.matmul(q, k)
|
||||||
|
|||||||
@@ -271,6 +271,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
self.k_lin = nn.Linear(dim, dim)
|
self.k_lin = nn.Linear(dim, dim)
|
||||||
self.v_lin = nn.Linear(dim, dim)
|
self.v_lin = nn.Linear(dim, dim)
|
||||||
self.out_lin = nn.Linear(dim, dim)
|
self.out_lin = nn.Linear(dim, dim)
|
||||||
|
self.pruned_heads = []
|
||||||
|
|
||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
attention_head_size = self.dim // self.n_heads
|
attention_head_size = self.dim // self.n_heads
|
||||||
@@ -278,6 +279,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
return
|
return
|
||||||
mask = torch.ones(self.n_heads, attention_head_size)
|
mask = torch.ones(self.n_heads, attention_head_size)
|
||||||
for head in heads:
|
for head in heads:
|
||||||
|
head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
|
||||||
mask[head] = 0
|
mask[head] = 0
|
||||||
mask = mask.view(-1).contiguous().eq(1)
|
mask = mask.view(-1).contiguous().eq(1)
|
||||||
index = torch.arange(len(mask))[mask].long()
|
index = torch.arange(len(mask))[mask].long()
|
||||||
@@ -289,6 +291,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
# Update hyper params
|
# Update hyper params
|
||||||
self.n_heads = self.n_heads - len(heads)
|
self.n_heads = self.n_heads - len(heads)
|
||||||
self.dim = attention_head_size * self.n_heads
|
self.dim = attention_head_size * self.n_heads
|
||||||
|
self.pruned_heads.extend(heads)
|
||||||
|
|
||||||
def forward(self, input, mask, kv=None, cache=None, head_mask=None):
|
def forward(self, input, mask, kv=None, cache=None, head_mask=None):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user