Blocks deletion from already deleted heads. Necessary integration test.

Now raises a warning when a head to be deleted already has been deleted. An integration test verifying the total pipeline (-> from config -> save model -> load model -> additional head pruning) has been added.
This commit is contained in:
LysandreJik
2019-08-21 21:20:39 -04:00
parent 719cb3738d
commit 87747518e9
6 changed files with 84 additions and 17 deletions

View File

@@ -561,6 +561,7 @@ class XLMModel(XLMPreTrainedModel):
if hasattr(config, "pruned_heads"):
pruned_heads = config.pruned_heads.copy().items()
config.pruned_heads = {}
for layer, heads in pruned_heads:
if self.attentions[int(layer)].n_heads == config.n_heads:
self.prune_heads({int(layer): list(map(int, heads))})