adding head pruning and tests
This commit is contained in:
@@ -51,12 +51,11 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
BERT_CONFIG_NAME = 'bert_config.json'
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
|
||||
def prune_linear_layer(layer, index, dim=-1):
|
||||
def prune_linear_layer(layer, index, dim=0):
|
||||
""" Prune a linear layer (a model parameters) to keep only entries in index.
|
||||
Return the pruned layer as a new layer with requires_grad=True.
|
||||
Used to remove heads.
|
||||
"""
|
||||
dim = (dim+100) % 2
|
||||
index = index.to(layer.weight.device)
|
||||
W = layer.weight.index_select(dim, index).clone().detach()
|
||||
if layer.bias is not None:
|
||||
@@ -394,7 +393,7 @@ class BertAttention(nn.Module):
|
||||
self.output = BertSelfOutput(config)
|
||||
|
||||
def prune_heads(self, heads):
|
||||
mask = torch.ones(self.self.n_heads, self.self.d_head)
|
||||
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
|
||||
for head in heads:
|
||||
mask[head] = 0
|
||||
mask = mask.view(-1).contiguous().eq(1)
|
||||
@@ -403,7 +402,7 @@ class BertAttention(nn.Module):
|
||||
self.self.query = prune_linear_layer(self.self.query, index)
|
||||
self.self.key = prune_linear_layer(self.self.key, index)
|
||||
self.self.value = prune_linear_layer(self.self.value, index)
|
||||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=0)
|
||||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
||||
# Update hyper params
|
||||
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
|
||||
Reference in New Issue
Block a user