updating tests
This commit is contained in:
@@ -165,9 +165,27 @@ class PreTrainedModel(nn.Module):
|
||||
# Save config in model
|
||||
self.config = config
|
||||
|
||||
def _get_resized_embeddings(self, old_embeddings, new_num_tokens):
|
||||
# Build new embeddings
|
||||
def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
|
||||
""" Build a resized Embedding Module from a provided token Embedding Module.
|
||||
Increasing the size will add newly initialized vectors at the end
|
||||
Reducing the size will remove vectors from the end
|
||||
|
||||
Args:
|
||||
new_num_tokens: (Optional) New number of tokens in the embedding matrix.
|
||||
Increasing the size will add newly initialized vectors at the end
|
||||
Reducing the size will remove vectors from the end
|
||||
If not provided or None: return the provided token Embedding Module.
|
||||
Return:
|
||||
Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
|
||||
"""
|
||||
if new_num_tokens is None:
|
||||
return old_embeddings
|
||||
|
||||
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
||||
if old_num_tokens == new_num_tokens:
|
||||
return old_embeddings
|
||||
|
||||
# Build new embeddings
|
||||
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
|
||||
new_embeddings.to(old_embeddings.weight.device)
|
||||
|
||||
@@ -180,18 +198,29 @@ class PreTrainedModel(nn.Module):
|
||||
|
||||
return new_embeddings
|
||||
|
||||
def resize_token_embeddings(self, new_num_tokens):
|
||||
""" Resize input token embeddings matrix.
|
||||
def _tie_or_clone_weights(self, first_module, second_module):
|
||||
""" Tie or clone module weights depending of weither we are using TorchScript or not
|
||||
"""
|
||||
if self.config.torchscript:
|
||||
first_module.weight = nn.Parameter(second_module.weight.clone())
|
||||
else:
|
||||
first_module.weight = second_module.weight
|
||||
|
||||
def resize_token_embeddings(self, new_num_tokens=None):
|
||||
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
|
||||
|
||||
Args:
|
||||
new_num_tokens: New number of tokens in the embedding matrix.
|
||||
new_num_tokens: (Optional) New number of tokens in the embedding matrix.
|
||||
Increasing the size will add newly initialized vectors at the end
|
||||
Reducing the size will remove vectors from the end
|
||||
If not provided or None: does nothing.
|
||||
Return:
|
||||
Pointer to the input tokens Embedding Module of the model
|
||||
"""
|
||||
if new_num_tokens == self.config.vocab_size:
|
||||
return
|
||||
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
||||
base_model._resize_token_embeddings(new_num_tokens)
|
||||
model_embeds = base_model._resize_token_embeddings(new_num_tokens)
|
||||
if new_num_tokens is None:
|
||||
return model_embeds
|
||||
|
||||
# Update base model and current model config
|
||||
self.config.vocab_size = new_num_tokens
|
||||
@@ -201,6 +230,8 @@ class PreTrainedModel(nn.Module):
|
||||
if hasattr(self, 'tie_weights'):
|
||||
self.tie_weights()
|
||||
|
||||
return model_embeds
|
||||
|
||||
def prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the base model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
|
||||
Reference in New Issue
Block a user