prune both attention and self-attention heads
This commit is contained in:
@@ -633,7 +633,7 @@ class BertModel(BertPreTrainedModel):
|
|||||||
See base class PreTrainedModel
|
See base class PreTrainedModel
|
||||||
"""
|
"""
|
||||||
for layer, heads in heads_to_prune.items():
|
for layer, heads in heads_to_prune.items():
|
||||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
self.encoder.layer[layer].self_attention.prune_heads(heads)
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
@@ -736,7 +736,8 @@ class BertDecoderModel(BertPreTrainedModel):
|
|||||||
See base class PreTrainedModel
|
See base class PreTrainedModel
|
||||||
"""
|
"""
|
||||||
for layer, heads in heads_to_prune.items():
|
for layer, heads in heads_to_prune.items():
|
||||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
self.decoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
self.decoder.layer[layer].self_attention.prune_heads(heads)
|
||||||
|
|
||||||
def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user