embeddings resizing + tie_weights

This commit is contained in:
thomwolf
2019-07-12 00:02:49 +02:00
parent 50e62a4cb4
commit bd404735a7
15 changed files with 196 additions and 332 deletions

View File

@@ -151,6 +151,7 @@ class PreTrainedModel(nn.Module):
# Class-level defaults (NOTE(review): presumably overridden by concrete model subclasses -- confirm).
# Empty mapping by default; presumably maps pretrained model names to archive locations -- verify against subclasses.
pretrained_model_archive_map = {}
# Default TensorFlow-weight loader: accepts (model, config, path), does nothing and returns None.
load_tf_weights = lambda model, config, path: None
# Attribute name under which a model stores its base model. Methods resolve it with
# getattr(self, self.base_model_prefix, self), so the empty default resolves to the model itself.
base_model_prefix = ""
# No input embeddings by default. NOTE(review): presumably set by subclasses to their token embedding module -- confirm.
input_embeddings = None
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedModel, self).__init__()
@@ -164,12 +165,48 @@ class PreTrainedModel(nn.Module):
# Save config in model
self.config = config
def _get_resized_embeddings(self, old_embeddings, new_num_tokens):
    """ Build a resized copy of an embedding module.

        Args:
            old_embeddings: module whose ``weight`` is a 2-D matrix of size
                (old_num_tokens, embedding_dim).
            new_num_tokens: number of rows of the returned embedding matrix.
                Growing keeps every old row and appends freshly initialized rows at the end.
                Shrinking keeps only the first ``new_num_tokens`` rows.

        Returns:
            A new ``nn.Embedding`` placed on the same device and using the same dtype as the
            old weights. ``old_embeddings`` itself is not modified.
    """
    old_weight = old_embeddings.weight
    old_num_tokens, old_embedding_dim = old_weight.size()

    # Build new embeddings. Match the dtype as well as the device: otherwise a half or
    # double precision model would silently get a float32 embedding matrix back.
    # (nn.Module.to modifies the module in place, so the return value is not needed.)
    new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
    new_embeddings.to(old_weight.device, dtype=old_weight.dtype)

    # Initialize all new embeddings (in particular added tokens)
    self.init_weights(new_embeddings)

    # Copy word embeddings from the previous weights; rows past the copied prefix keep
    # the initialization performed above.
    num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
    new_embeddings.weight.data[:num_tokens_to_copy, :] = old_weight.data[:num_tokens_to_copy, :]

    return new_embeddings
def resize_token_embeddings(self, new_num_tokens):
    """ Change the number of rows of the input token embedding matrix.

        Args:
            new_num_tokens: requested number of tokens in the embedding matrix.
                A larger value appends newly initialized vectors at the end,
                a smaller value drops vectors from the end.

        Nothing happens when ``new_num_tokens`` already equals ``self.config.vocab_size``.
    """
    if self.config.vocab_size != new_num_tokens:
        # Resolve the base model: the attribute named by `base_model_prefix` when present,
        # otherwise this model itself.
        target_model = getattr(self, self.base_model_prefix, self)
        target_model._resize_token_embeddings(new_num_tokens)

        # Record the new vocabulary size on the config and on the base model
        self.config.vocab_size = new_num_tokens
        target_model.vocab_size = new_num_tokens

        # Models that define `tie_weights` get it called again after the resize
        if hasattr(self, 'tie_weights'):
            self.tie_weights()
def prune_heads(self, heads_to_prune):
    """ Prunes heads of the base model.

        Args:
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
    """
    # Get the base model if needed: the attribute named by `base_model_prefix` when present,
    # otherwise this model itself.
    base_model = getattr(self, self.base_model_prefix, self)
    # Delegate exactly once. The previous body resolved the base model twice and called
    # `_prune_heads` twice, so the same request was applied to already-pruned layers.
    base_model._prune_heads(heads_to_prune)
def save_pretrained(self, save_directory):
""" Save a model with its configuration file to a directory, so that it