Blocks deletion from already deleted heads. Necessary integration test.
Now raises a warning when a head to be deleted already has been deleted. An integration test verifying the total pipeline (-> from config -> save model -> load model -> additional head pruning) has been added.
This commit is contained in:
@@ -201,6 +201,10 @@ class PretrainedConfig(object):
|
||||
# Load config
|
||||
config = cls.from_json_file(resolved_config_file)
|
||||
|
||||
if hasattr(config, 'pruned_heads'):
|
||||
config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
|
||||
|
||||
|
||||
# Update config with kwargs if needed
|
||||
to_remove = []
|
||||
for key, value in kwargs.items():
|
||||
@@ -365,15 +369,22 @@ class PreTrainedModel(nn.Module):
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
||||
|
||||
to_be_pruned = {}
|
||||
|
||||
for layer, heads in heads_to_prune.items():
|
||||
if str(layer) not in self.config.pruned_heads:
|
||||
self.config.pruned_heads[str(layer)] = heads
|
||||
if int(layer) not in self.config.pruned_heads:
|
||||
self.config.pruned_heads[int(layer)] = heads
|
||||
to_be_pruned[int(layer)] = heads
|
||||
else:
|
||||
for head in heads:
|
||||
if head not in self.config.pruned_heads[str(layer)]:
|
||||
self.config.pruned_heads[str(layer)].append(head)
|
||||
if head not in self.config.pruned_heads[int(layer)]:
|
||||
self.config.pruned_heads[int(layer)].append(head)
|
||||
to_be_pruned[int(layer)].append(head)
|
||||
else:
|
||||
logger.warning(f"Tried to remove head {head} of layer {layer} but it was already removed. "
|
||||
f"The removed heads are {heads_to_prune}")
|
||||
|
||||
base_model._prune_heads(heads_to_prune)
|
||||
base_model._prune_heads(to_be_pruned)
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
""" Save a model and its configuration file to a directory, so that it
|
||||
|
||||
Reference in New Issue
Block a user