updated pruning logic with sets - Bert and GPT-2

This commit is contained in:
thomwolf
2019-08-31 01:59:07 +02:00
committed by LysandreJik
parent 0c8e823b03
commit bdb4409ed8
3 changed files with 45 additions and 61 deletions

View File

@@ -202,8 +202,7 @@ class PretrainedConfig(object):
config = cls.from_json_file(resolved_config_file)
if hasattr(config, 'pruned_heads'):
config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items())
# Update config with kwargs if needed
to_remove = []
@@ -316,7 +315,7 @@ class PreTrainedModel(nn.Module):
new_embeddings.to(old_embeddings.weight.device)
# initialize all new embeddings (in particular added tokens)
self.init_weights(new_embeddings)
self._init_weights(new_embeddings)
# Copy word embeddings from the previous weights
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
@@ -360,36 +359,31 @@ class PreTrainedModel(nn.Module):
return model_embeds
def init_weights(self):
""" Initialize and prunes weights if needed. """
# Initialize weights
self.apply(self._init_weights)
# Prune heads if needed
if self.config.pruned_heads:
self.prune_heads(self.config.pruned_heads)
def prune_heads(self, heads_to_prune):
""" Prunes heads of the base model.
Arguments:
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
"""
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
to_be_pruned = {}
# save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
for layer, heads in heads_to_prune.items():
if int(layer) not in self.config.pruned_heads:
self.config.pruned_heads[int(layer)] = heads
to_be_pruned[int(layer)] = heads
else:
for head in heads:
if head not in self.config.pruned_heads[int(layer)]:
self.config.pruned_heads[int(layer)].append(head)
union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
if int(layer) in to_be_pruned:
to_be_pruned[int(layer)].append(head)
else:
to_be_pruned[int(layer)] = [head]
else:
logger.warning("Tried to remove head " + str(head) +
" of layer " + str(layer) +
" but it was already removed. The current removed heads are " + str(heads_to_prune))
base_model._prune_heads(to_be_pruned)
base_model._prune_heads(heads_to_prune)
def save_pretrained(self, save_directory):
""" Save a model and its configuration file to a directory, so that it