From 7220d47a1c0d6b6c535e27bd1392a885eea842fd Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 17 Jun 2019 13:20:45 +0200 Subject: [PATCH] adding head pruning and tests --- pytorch_pretrained_bert/modeling.py | 7 +++-- tests/modeling_test.py | 42 +++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index b40e5825b3..9cf02d363c 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -51,12 +51,11 @@ PRETRAINED_MODEL_ARCHIVE_MAP = { BERT_CONFIG_NAME = 'bert_config.json' TF_WEIGHTS_NAME = 'model.ckpt' -def prune_linear_layer(layer, index, dim=-1): +def prune_linear_layer(layer, index, dim=0): """ Prune a linear layer (a model parameters) to keep only entries in index. Return the pruned layer as a new layer with requires_grad=True. Used to remove heads. """ - dim = (dim+100) % 2 index = index.to(layer.weight.device) W = layer.weight.index_select(dim, index).clone().detach() if layer.bias is not None: @@ -394,7 +393,7 @@ class BertAttention(nn.Module): self.output = BertSelfOutput(config) def prune_heads(self, heads): - mask = torch.ones(self.self.n_heads, self.self.d_head) + mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) for head in heads: mask[head] = 0 mask = mask.view(-1).contiguous().eq(1) @@ -403,7 +402,7 @@ class BertAttention(nn.Module): self.self.query = prune_linear_layer(self.self.query, index) self.self.key = prune_linear_layer(self.self.key, index) self.self.value = prune_linear_layer(self.self.value, index) - self.output.dense = prune_linear_layer(self.output.dense, index, dim=0) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) # Update hyper params self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads diff --git a/tests/modeling_test.py b/tests/modeling_test.py index 4c78ead767..b23edf1aea 100644 --- a/tests/modeling_test.py +++ b/tests/modeling_test.py @@ -334,6 +334,47 @@ class BertModelTest(unittest.TestCase): self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads) + def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): + for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction, + BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, + BertForTokenClassification): + if model_class in [BertForSequenceClassification, + BertForTokenClassification]: + model = model_class(config=config, + num_labels=self.num_labels, + keep_multihead_output=True) + else: + model = model_class(config=config, keep_multihead_output=True) + model.eval() + bert_model = model if isinstance(model, BertModel) else model.bert + heads_to_prune = {0: list(range(1, self.num_attention_heads)), + -1: [0]} + bert_model.prune_heads(heads_to_prune) + output = model(input_ids, token_type_ids, input_mask) + + if isinstance(model, BertModel): + output = sum(t.sum() for t in output[0]) + elif isinstance(output, (list, tuple)): + output = sum(t.sum() for t in output) + output = output.sum() + output.backward() + multihead_outputs = bert_model.get_multihead_outputs() + + self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers) + self.parent.assertListEqual( + list(multihead_outputs[0].size()), + [self.batch_size, 1, + self.seq_length, self.hidden_size // self.num_attention_heads]) + self.parent.assertListEqual( + list(multihead_outputs[1].size()), + [self.batch_size, self.num_attention_heads, + self.seq_length, self.hidden_size // self.num_attention_heads]) + self.parent.assertListEqual( + list(multihead_outputs[-1].size()), + [self.batch_size, self.num_attention_heads-1, + self.seq_length, self.hidden_size // self.num_attention_heads]) + + def test_default(self): self.run_tester(BertModelTest.BertModelTester(self)) @@ -394,6 +435,7 @@ class BertModelTest(unittest.TestCase): tester.create_and_check_bert_for_attentions(*config_and_inputs) tester.create_and_check_bert_for_headmasking(*config_and_inputs) + tester.create_and_check_bert_for_head_pruning(*config_and_inputs) @classmethod def ids_tensor(cls, shape, vocab_size, rng=None, name=None):