prune both attention and self-attention heads
This commit is contained in:
@@ -633,7 +633,7 @@ class BertModel(BertPreTrainedModel):
|
||||
See base class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
self.encoder.layer[layer].self_attention.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||
if attention_mask is None:
|
||||
@@ -736,7 +736,8 @@ class BertDecoderModel(BertPreTrainedModel):
|
||||
See base class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
self.decoder.layer[layer].attention.prune_heads(heads)
|
||||
self.decoder.layer[layer].self_attention.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||
if attention_mask is None:
|
||||
|
||||
Reference in New Issue
Block a user