Added patch to remaining models

This commit is contained in:
LysandreJik
2019-08-29 17:20:11 -04:00
parent 0cd283522a
commit 0c8e823b03
3 changed files with 9 additions and 0 deletions

View File

@@ -337,12 +337,14 @@ class BertAttention(nn.Module):
super(BertAttention, self).__init__()
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
self.pruned_heads = []
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
for head in heads:
head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
@@ -354,6 +356,7 @@ class BertAttention(nn.Module):
# Update hyper params
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads.extend(heads)
def forward(self, input_tensor, attention_mask, head_mask=None):
self_outputs = self.self(input_tensor, attention_mask, head_mask)