Blocks deletion from already deleted heads. Necessary integration test.

Now raises a warning when a head to be deleted already has been deleted. An integration test verifying the total pipeline (-> from config -> save model -> load model -> additional head pruning) has been added.
This commit is contained in:
LysandreJik
2019-08-21 21:20:39 -04:00
parent 719cb3738d
commit 87747518e9
6 changed files with 84 additions and 17 deletions

View File

@@ -455,6 +455,7 @@ class GPT2Model(GPT2PreTrainedModel):
if hasattr(config, "pruned_heads"):
pruned_heads = config.pruned_heads.copy().items()
config.pruned_heads = {}
for layer, heads in pruned_heads:
if self.h[int(layer)].attn.n_head == config.n_head:
self.prune_heads({int(layer): list(map(int, heads))})