updated pruning logic with sets - Bert and GPT-2
This commit is contained in:
@@ -202,8 +202,7 @@ class PretrainedConfig(object):
|
||||
config = cls.from_json_file(resolved_config_file)
|
||||
|
||||
if hasattr(config, 'pruned_heads'):
|
||||
config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
|
||||
|
||||
config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items())
|
||||
|
||||
# Update config with kwargs if needed
|
||||
to_remove = []
|
||||
@@ -316,7 +315,7 @@ class PreTrainedModel(nn.Module):
|
||||
new_embeddings.to(old_embeddings.weight.device)
|
||||
|
||||
# initialize all new embeddings (in particular added tokens)
|
||||
self.init_weights(new_embeddings)
|
||||
self._init_weights(new_embeddings)
|
||||
|
||||
# Copy word embeddings from the previous weights
|
||||
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
||||
@@ -360,36 +359,31 @@ class PreTrainedModel(nn.Module):
|
||||
|
||||
return model_embeds
|
||||
|
||||
def init_weights(self):
|
||||
""" Initialize weights and prune heads if needed. """
|
||||
# Initialize weights
|
||||
self.apply(self._init_weights)
|
||||
|
||||
# Prune heads if needed
|
||||
if self.config.pruned_heads:
|
||||
self.prune_heads(self.config.pruned_heads)
|
||||
|
||||
def prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the base model.
|
||||
|
||||
Arguments:
|
||||
|
||||
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
|
||||
E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
||||
|
||||
to_be_pruned = {}
|
||||
|
||||
# save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
|
||||
for layer, heads in heads_to_prune.items():
|
||||
if int(layer) not in self.config.pruned_heads:
|
||||
self.config.pruned_heads[int(layer)] = heads
|
||||
to_be_pruned[int(layer)] = heads
|
||||
else:
|
||||
for head in heads:
|
||||
if head not in self.config.pruned_heads[int(layer)]:
|
||||
self.config.pruned_heads[int(layer)].append(head)
|
||||
union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
|
||||
self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as a list for JSON
|
||||
|
||||
if int(layer) in to_be_pruned:
|
||||
to_be_pruned[int(layer)].append(head)
|
||||
else:
|
||||
to_be_pruned[int(layer)] = [head]
|
||||
else:
|
||||
logger.warning("Tried to remove head " + str(head) +
|
||||
" of layer " + str(layer) +
|
||||
" but it was already removed. The current removed heads are " + str(heads_to_prune))
|
||||
|
||||
base_model._prune_heads(to_be_pruned)
|
||||
base_model._prune_heads(heads_to_prune)
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
""" Save a model and its configuration file to a directory, so that it
|
||||
|
||||
Reference in New Issue
Block a user