embeddings resizing + tie_weights
This commit is contained in:
@@ -151,6 +151,7 @@ class PreTrainedModel(nn.Module):
|
||||
pretrained_model_archive_map = {}
|
||||
load_tf_weights = lambda model, config, path: None
|
||||
base_model_prefix = ""
|
||||
input_embeddings = None
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(PreTrainedModel, self).__init__()
|
||||
@@ -164,12 +165,48 @@ class PreTrainedModel(nn.Module):
|
||||
# Save config in model
|
||||
self.config = config
|
||||
|
||||
def _get_resized_embeddings(self, old_embeddings, new_num_tokens):
|
||||
# Build new embeddings
|
||||
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
||||
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
|
||||
new_embeddings.to(old_embeddings.weight.device)
|
||||
|
||||
# initialize all new embeddings (in particular added tokens)
|
||||
self.init_weights(new_embeddings)
|
||||
|
||||
# Copy word embeddings from the previous weights
|
||||
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
||||
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
|
||||
|
||||
return new_embeddings
|
||||
|
||||
def resize_token_embeddings(self, new_num_tokens):
    """ Resize input token embeddings matrix.

    Delegates the actual resizing to the base model's
    ``_resize_token_embeddings``, then records the new size on both
    ``self.config`` and the base model, and finally calls
    ``self.tie_weights()`` when the model defines it.

    Args:
        new_num_tokens: New number of tokens in the embedding matrix.
            Increasing the size will add newly initialized vectors at the end
            Reducing the size will remove vectors from the end
            ``None`` or the current ``config.vocab_size`` leaves the model
            untouched.
    """
    # No-op when no size is given or the size is unchanged. Without the
    # ``None`` guard, ``None`` would be forwarded to the base model and
    # written into the config as the vocabulary size.
    if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
        return

    base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
    base_model._resize_token_embeddings(new_num_tokens)

    # Update base model and current model config
    self.config.vocab_size = new_num_tokens
    base_model.vocab_size = new_num_tokens

    # Tie weights again if needed
    if hasattr(self, 'tie_weights'):
        self.tie_weights()
|
||||
|
||||
def prune_heads(self, heads_to_prune):
    """ Prunes heads of the base model.

    Args:
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
    """
    # Resolve the base model once and prune once; calling ``_prune_heads``
    # a second time with the same argument would prune again.
    base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
    base_model._prune_heads(heads_to_prune)
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
""" Save a model with its configuration file to a directory, so that it
|
||||
|
||||
Reference in New Issue
Block a user