Blocks deletion from already deleted heads. Necessary integration test.

Now raises a warning when a head to be deleted already has been deleted. An integration test verifying the total pipeline (-> from config -> save model -> load model -> additional head pruning) has been added.
This commit is contained in:
LysandreJik
2019-08-21 21:20:39 -04:00
parent 719cb3738d
commit 87747518e9
6 changed files with 84 additions and 17 deletions

View File

@@ -201,6 +201,10 @@ class PretrainedConfig(object):
# Load config
config = cls.from_json_file(resolved_config_file)
if hasattr(config, 'pruned_heads'):
config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
# Update config with kwargs if needed
to_remove = []
for key, value in kwargs.items():
@@ -365,15 +369,22 @@ class PreTrainedModel(nn.Module):
"""
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
to_be_pruned = {}
for layer, heads in heads_to_prune.items():
if str(layer) not in self.config.pruned_heads:
self.config.pruned_heads[str(layer)] = heads
if int(layer) not in self.config.pruned_heads:
self.config.pruned_heads[int(layer)] = heads
to_be_pruned[int(layer)] = heads
else:
for head in heads:
if head not in self.config.pruned_heads[str(layer)]:
self.config.pruned_heads[str(layer)].append(head)
if head not in self.config.pruned_heads[int(layer)]:
self.config.pruned_heads[int(layer)].append(head)
to_be_pruned[int(layer)].append(head)
else:
logger.warning(f"Tried to remove head {head} of layer {layer} but it was already removed. "
f"The removed heads are {heads_to_prune}")
base_model._prune_heads(heads_to_prune)
base_model._prune_heads(to_be_pruned)
def save_pretrained(self, save_directory):
""" Save a model and its configuration file to a directory, so that it