From 51261167b4a1de53cd38cc2b1553e5d71ba360ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 10 Oct 2019 12:17:22 +0200 Subject: [PATCH] prune both attention and self-attention heads --- transformers/modeling_bert.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 9e03c2f8d4..fddf5d52a2 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -633,7 +633,7 @@ class BertModel(BertPreTrainedModel): See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) + self.encoder.layer[layer].self_attention.prune_heads(heads) def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): if attention_mask is None: @@ -736,7 +736,8 @@ class BertDecoderModel(BertPreTrainedModel): See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) + self.decoder.layer[layer].attention.prune_heads(heads) + self.decoder.layer[layer].self_attention.prune_heads(heads) def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): if attention_mask is None: