adding head pruning and tests

This commit is contained in:
thomwolf
2019-06-17 13:20:45 +02:00
parent 8415a38b23
commit 7220d47a1c
2 changed files with 45 additions and 4 deletions

View File

@@ -51,12 +51,11 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
BERT_CONFIG_NAME = 'bert_config.json'
TF_WEIGHTS_NAME = 'model.ckpt'
def prune_linear_layer(layer, index, dim=-1):
def prune_linear_layer(layer, index, dim=0):
""" Prune a linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads.
"""
dim = (dim+100) % 2
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
if layer.bias is not None:
@@ -394,7 +393,7 @@ class BertAttention(nn.Module):
self.output = BertSelfOutput(config)
def prune_heads(self, heads):
mask = torch.ones(self.self.n_heads, self.self.d_head)
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
for head in heads:
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
@@ -403,7 +402,7 @@ class BertAttention(nn.Module):
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=0)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads