[PyTorch] Refactor Resize Token Embeddings (#8880)
* fix resize tokens
* correct mobile_bert
* move embedding fix into modeling_utils.py
* refactor
* fix lm head resize
* refactor
* break lines to make sylvain happy
* add news tests
* fix typo
* improve test
* skip bart-like for now
* check if base_model = get(...) is necessary
* clean files
* improve test
* fix tests
* revert style templates
* Update templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
This commit is contained in:
committed by
GitHub
parent
e52f9c0ade
commit
443f67e887
@@ -605,14 +605,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
Return:
|
||||
:obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
||||
model_embeds = base_model._resize_token_embeddings(new_num_tokens)
|
||||
model_embeds = self._resize_token_embeddings(new_num_tokens)
|
||||
if new_num_tokens is None:
|
||||
return model_embeds
|
||||
|
||||
# Update base model and current model config
|
||||
self.config.vocab_size = new_num_tokens
|
||||
base_model.vocab_size = new_num_tokens
|
||||
self.vocab_size = new_num_tokens
|
||||
|
||||
# Tie weights again if needed
|
||||
self.tie_weights()
|
||||
@@ -623,6 +622,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
old_embeddings = self.get_input_embeddings()
|
||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||
self.set_input_embeddings(new_embeddings)
|
||||
|
||||
# if word embeddings are not tied, make sure that lm head is resized as well
|
||||
if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
|
||||
old_lm_head = self.get_output_embeddings()
|
||||
new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
|
||||
self.set_output_embeddings(new_lm_head)
|
||||
|
||||
return self.get_input_embeddings()
|
||||
|
||||
def _get_resized_embeddings(
|
||||
@@ -653,9 +659,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
if old_num_tokens == new_num_tokens:
|
||||
return old_embeddings
|
||||
|
||||
if not isinstance(old_embeddings, nn.Embedding):
|
||||
raise TypeError(
|
||||
f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}."
|
||||
f"You should either use a different resize function or make sure that `old_embeddings` are an instance of {nn.Embedding}."
|
||||
)
|
||||
|
||||
# Build new embeddings
|
||||
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
|
||||
new_embeddings.to(old_embeddings.weight.device)
|
||||
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(self.device)
|
||||
|
||||
# initialize all new embeddings (in particular added tokens)
|
||||
self._init_weights(new_embeddings)
|
||||
@@ -666,6 +677,68 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
|
||||
return new_embeddings
|
||||
|
||||
def _get_resized_lm_head(
|
||||
self, old_lm_head: torch.nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
|
||||
) -> torch.nn.Linear:
|
||||
"""
|
||||
Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
|
||||
vectors at the end. Reducing the size will remove vectors from the end
|
||||
|
||||
Args:
|
||||
old_lm_head (:obj:`torch.nn.Linear`):
|
||||
Old lm head liner layer to be resized.
|
||||
new_num_tokens (:obj:`int`, `optional`):
|
||||
New number of tokens in the linear matrix.
|
||||
|
||||
Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
|
||||
vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
|
||||
:obj:`torch.nn.Linear`` module of the model without doing anything.
|
||||
transposed (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether ``old_lm_head`` is transposed or not. If True ``old_lm_head.size()`` is ``lm_head_dim,
|
||||
vocab_size`` else ``vocab_size, lm_head_dim``.
|
||||
|
||||
Return:
|
||||
:obj:`torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if
|
||||
:obj:`new_num_tokens` is :obj:`None`
|
||||
"""
|
||||
if new_num_tokens is None:
|
||||
return old_lm_head
|
||||
|
||||
old_num_tokens, old_lm_head_dim = (
|
||||
old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
|
||||
)
|
||||
|
||||
if old_num_tokens == new_num_tokens:
|
||||
return old_lm_head
|
||||
|
||||
if not isinstance(old_lm_head, nn.Linear):
|
||||
raise TypeError(
|
||||
f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}."
|
||||
f"You should either use a different resize function or make sure that `old_embeddings` are an instance of {nn.Linear}."
|
||||
)
|
||||
|
||||
# Build new lm head
|
||||
new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
|
||||
has_new_lm_head_bias = old_lm_head.bias is not None
|
||||
new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias).to(self.device)
|
||||
|
||||
# initialize new lm head (in particular added tokens)
|
||||
self._init_weights(new_lm_head)
|
||||
|
||||
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
||||
|
||||
# Copy old lm head weights to new lm head
|
||||
if not transposed:
|
||||
new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
|
||||
else:
|
||||
new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
|
||||
|
||||
# Copy bias weights to new lm head
|
||||
if has_new_lm_head_bias:
|
||||
new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
|
||||
|
||||
return new_lm_head
|
||||
|
||||
def init_weights(self):
|
||||
"""
|
||||
Initializes and prunes weights if needed.
|
||||
|
||||
Reference in New Issue
Block a user