Pruning saved to configuration first try
This commit is contained in:
@@ -104,6 +104,7 @@ class PretrainedConfig(object):
|
||||
self.output_attentions = kwargs.pop('output_attentions', False)
|
||||
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
|
||||
self.torchscript = kwargs.pop('torchscript', False)
|
||||
self.pruned_heads = kwargs.pop('pruned_heads', {})
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
""" Save a configuration object to the directory `save_directory`, so that it
|
||||
@@ -363,6 +364,15 @@ class PreTrainedModel(nn.Module):
|
||||
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
||||
|
||||
for layer, heads in heads_to_prune.items():
|
||||
if str(layer) not in self.config.pruned_heads:
|
||||
self.config.pruned_heads[str(layer)] = heads
|
||||
else:
|
||||
for head in heads:
|
||||
if head not in self.config.pruned_heads[str(layer)]:
|
||||
self.config.pruned_heads[str(layer)].append(head)
|
||||
|
||||
base_model._prune_heads(heads_to_prune)
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
|
||||
Reference in New Issue
Block a user