From abb23a78bab17ec09dde4635c32f4aa21c15fa83 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 7 Nov 2019 17:09:16 +0000 Subject: [PATCH] Head pruning for ALBERT --- transformers/modeling_albert.py | 42 ++++++++++++++++++++++ transformers/tests/modeling_albert_test.py | 12 ++++--- 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/transformers/modeling_albert.py b/transformers/modeling_albert.py index 2540218e69..89ece4b61e 100644 --- a/transformers/modeling_albert.py +++ b/transformers/modeling_albert.py @@ -145,6 +145,29 @@ class AlbertAttention(BertSelfAttention): self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.pruned_heads = set() + def prune_heads(self, heads): + if len(heads) == 0: + return + mask = torch.ones(self.num_attention_heads, self.attention_head_size) + heads = set(heads) - self.pruned_heads # Convert to set and emove already pruned heads + for head in heads: + # Compute how many pruned heads are before the head and move the index accordingly + head = head - sum(1 if h < head else 0 for h in self.pruned_heads) + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index = torch.arange(len(mask))[mask].long() + + # Prune linear layers + self.query = prune_linear_layer(self.query, index) + self.key = prune_linear_layer(self.key, index) + self.value = prune_linear_layer(self.value, index) + self.dense = prune_linear_layer(self.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.num_attention_heads = self.num_attention_heads - len(heads) + self.all_head_size = self.attention_head_size * self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + def forward(self, input_ids, attention_mask=None, head_mask=None): mixed_query_layer = self.query(input_ids) mixed_key_layer = self.key(input_ids) @@ -409,6 +432,25 @@ class AlbertModel(AlbertPreTrainedModel): self.embeddings.word_embeddings = new_embeddings return self.embeddings.word_embeddings + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + ALBERT has a different architecture in that its layers are shared across groups, which then has inner groups. + If an ALBERT model has 12 hidden layers and 2 hidden groups, with two inner groups, there + is a total of 4 different layers. + + These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden layer, + while [2,3] correspond to the two inner groups of the second hidden layer. + + Any layer with in index other than [0,1,2,3] will result in an error. + See base class PreTrainedModel for more information about head pruning + """ + for layer, heads in heads_to_prune.items(): + group_idx = int(layer / self.config.inner_group_num) + inner_group_idx = int(layer - group_idx * self.config.inner_group_num) + self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads) + + def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): if attention_mask is None: attention_mask = torch.ones_like(input_ids) diff --git a/transformers/tests/modeling_albert_test.py b/transformers/tests/modeling_albert_test.py index 46a2eeb729..979e0488eb 100644 --- a/transformers/tests/modeling_albert_test.py +++ b/transformers/tests/modeling_albert_test.py @@ -35,7 +35,6 @@ else: class AlbertModelTest(CommonTestCases.CommonModelTester): all_model_classes = (AlbertModel, AlbertForMaskedLM) if is_torch_available() else () - test_pruning = False test_head_masking = False class AlbertModelTester(object): @@ -49,9 +48,10 @@ class AlbertModelTest(CommonTestCases.CommonModelTester): use_token_type_ids=True, use_labels=True, vocab_size=99, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, + hidden_size=36, + num_hidden_layers=6, + num_hidden_groups=6, + num_attention_heads=6, intermediate_size=37, hidden_act="gelu", hidden_dropout_prob=0.1, @@ -86,6 +86,7 @@ class AlbertModelTest(CommonTestCases.CommonModelTester): self.num_labels = num_labels self.num_choices = num_choices self.scope = scope + self.num_hidden_groups = num_hidden_groups def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -117,7 +118,8 @@ class AlbertModelTest(CommonTestCases.CommonModelTester): attention_probs_dropout_prob=self.attention_probs_dropout_prob, max_position_embeddings=self.max_position_embeddings, type_vocab_size=self.type_vocab_size, - initializer_range=self.initializer_range) + initializer_range=self.initializer_range, + num_hidden_groups=self.num_hidden_groups) return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels