adding head pruning and tests
This commit is contained in:
@@ -51,12 +51,11 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|||||||
BERT_CONFIG_NAME = 'bert_config.json'
|
BERT_CONFIG_NAME = 'bert_config.json'
|
||||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||||
|
|
||||||
def prune_linear_layer(layer, index, dim=-1):
|
def prune_linear_layer(layer, index, dim=0):
|
||||||
""" Prune a linear layer (a model parameters) to keep only entries in index.
|
""" Prune a linear layer (a model parameters) to keep only entries in index.
|
||||||
Return the pruned layer as a new layer with requires_grad=True.
|
Return the pruned layer as a new layer with requires_grad=True.
|
||||||
Used to remove heads.
|
Used to remove heads.
|
||||||
"""
|
"""
|
||||||
dim = (dim+100) % 2
|
|
||||||
index = index.to(layer.weight.device)
|
index = index.to(layer.weight.device)
|
||||||
W = layer.weight.index_select(dim, index).clone().detach()
|
W = layer.weight.index_select(dim, index).clone().detach()
|
||||||
if layer.bias is not None:
|
if layer.bias is not None:
|
||||||
@@ -394,7 +393,7 @@ class BertAttention(nn.Module):
|
|||||||
self.output = BertSelfOutput(config)
|
self.output = BertSelfOutput(config)
|
||||||
|
|
||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
mask = torch.ones(self.self.n_heads, self.self.d_head)
|
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
|
||||||
for head in heads:
|
for head in heads:
|
||||||
mask[head] = 0
|
mask[head] = 0
|
||||||
mask = mask.view(-1).contiguous().eq(1)
|
mask = mask.view(-1).contiguous().eq(1)
|
||||||
@@ -403,7 +402,7 @@ class BertAttention(nn.Module):
|
|||||||
self.self.query = prune_linear_layer(self.self.query, index)
|
self.self.query = prune_linear_layer(self.self.query, index)
|
||||||
self.self.key = prune_linear_layer(self.self.key, index)
|
self.self.key = prune_linear_layer(self.self.key, index)
|
||||||
self.self.value = prune_linear_layer(self.self.value, index)
|
self.self.value = prune_linear_layer(self.self.value, index)
|
||||||
self.output.dense = prune_linear_layer(self.output.dense, index, dim=0)
|
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
||||||
# Update hyper params
|
# Update hyper params
|
||||||
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
||||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||||
|
|||||||
@@ -334,6 +334,47 @@ class BertModelTest(unittest.TestCase):
|
|||||||
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
|
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
|
||||||
|
|
||||||
|
|
||||||
|
def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
|
for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
|
||||||
|
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
|
||||||
|
BertForTokenClassification):
|
||||||
|
if model_class in [BertForSequenceClassification,
|
||||||
|
BertForTokenClassification]:
|
||||||
|
model = model_class(config=config,
|
||||||
|
num_labels=self.num_labels,
|
||||||
|
keep_multihead_output=True)
|
||||||
|
else:
|
||||||
|
model = model_class(config=config, keep_multihead_output=True)
|
||||||
|
model.eval()
|
||||||
|
bert_model = model if isinstance(model, BertModel) else model.bert
|
||||||
|
heads_to_prune = {0: list(range(1, self.num_attention_heads)),
|
||||||
|
-1: [0]}
|
||||||
|
bert_model.prune_heads(heads_to_prune)
|
||||||
|
output = model(input_ids, token_type_ids, input_mask)
|
||||||
|
|
||||||
|
if isinstance(model, BertModel):
|
||||||
|
output = sum(t.sum() for t in output[0])
|
||||||
|
elif isinstance(output, (list, tuple)):
|
||||||
|
output = sum(t.sum() for t in output)
|
||||||
|
output = output.sum()
|
||||||
|
output.backward()
|
||||||
|
multihead_outputs = bert_model.get_multihead_outputs()
|
||||||
|
|
||||||
|
self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(multihead_outputs[0].size()),
|
||||||
|
[self.batch_size, 1,
|
||||||
|
self.seq_length, self.hidden_size // self.num_attention_heads])
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(multihead_outputs[1].size()),
|
||||||
|
[self.batch_size, self.num_attention_heads,
|
||||||
|
self.seq_length, self.hidden_size // self.num_attention_heads])
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(multihead_outputs[-1].size()),
|
||||||
|
[self.batch_size, self.num_attention_heads-1,
|
||||||
|
self.seq_length, self.hidden_size // self.num_attention_heads])
|
||||||
|
|
||||||
|
|
||||||
def test_default(self):
|
def test_default(self):
|
||||||
self.run_tester(BertModelTest.BertModelTester(self))
|
self.run_tester(BertModelTest.BertModelTester(self))
|
||||||
|
|
||||||
@@ -394,6 +435,7 @@ class BertModelTest(unittest.TestCase):
|
|||||||
|
|
||||||
tester.create_and_check_bert_for_attentions(*config_and_inputs)
|
tester.create_and_check_bert_for_attentions(*config_and_inputs)
|
||||||
tester.create_and_check_bert_for_headmasking(*config_and_inputs)
|
tester.create_and_check_bert_for_headmasking(*config_and_inputs)
|
||||||
|
tester.create_and_check_bert_for_head_pruning(*config_and_inputs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
|
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
|
||||||
|
|||||||
Reference in New Issue
Block a user