Pruning saved to configuration first try

This commit is contained in:
Lysandre
2019-08-19 22:43:02 -04:00
committed by LysandreJik
parent d7a4c3252e
commit 42e00cf9e1
3 changed files with 72 additions and 0 deletions

View File

@@ -104,6 +104,7 @@ class PretrainedConfig(object):
self.output_attentions = kwargs.pop('output_attentions', False)
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
self.torchscript = kwargs.pop('torchscript', False)
self.pruned_heads = kwargs.pop('pruned_heads', {})
def save_pretrained(self, save_directory):
""" Save a configuration object to the directory `save_directory`, so that it
@@ -363,6 +364,15 @@ class PreTrainedModel(nn.Module):
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
"""
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
for layer, heads in heads_to_prune.items():
if str(layer) not in self.config.pruned_heads:
self.config.pruned_heads[str(layer)] = heads
else:
for head in heads:
if head not in self.config.pruned_heads[str(layer)]:
self.config.pruned_heads[str(layer)].append(head)
base_model._prune_heads(heads_to_prune)
def save_pretrained(self, save_directory):