Blocks deletion from already deleted heads. Necessary integration test.
Now raises a warning when a head to be deleted already has been deleted. An integration test verifying the total pipeline (-> from config -> save model -> load model -> additional head pruning) has been added.
This commit is contained in:
@@ -651,6 +651,7 @@ class BertModel(BertPreTrainedModel):
|
||||
|
||||
if hasattr(config, "pruned_heads"):
|
||||
pruned_heads = config.pruned_heads.copy().items()
|
||||
config.pruned_heads = {}
|
||||
for layer, heads in pruned_heads:
|
||||
if self.encoder.layer[int(layer)].attention.self.num_attention_heads == config.num_attention_heads:
|
||||
self.prune_heads({int(layer): list(map(int, heads))})
|
||||
|
||||
Reference in New Issue
Block a user