[PyTorch] Refactor Resize Token Embeddings (#8880)
* fix resize tokens
* correct mobile_bert
* move embedding fix into modeling_utils.py
* refactor
* fix lm head resize
* refactor
* break lines to make sylvain happy
* add news tests
* fix typo
* improve test
* skip bart-like for now
* check if base_model = get(...) is necessary
* clean files
* improve test
* fix tests
* revert style templates
* Update templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
This commit is contained in:
committed by
GitHub
parent
e52f9c0ade
commit
443f67e887
@@ -605,14 +605,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
Return:
|
||||
:obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
|
||||
"""
|
||||
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
||||
model_embeds = base_model._resize_token_embeddings(new_num_tokens)
|
||||
model_embeds = self._resize_token_embeddings(new_num_tokens)
|
||||
if new_num_tokens is None:
|
||||
return model_embeds
|
||||
|
||||
# Update base model and current model config
|
||||
self.config.vocab_size = new_num_tokens
|
||||
base_model.vocab_size = new_num_tokens
|
||||
self.vocab_size = new_num_tokens
|
||||
|
||||
# Tie weights again if needed
|
||||
self.tie_weights()
|
||||
@@ -623,6 +622,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
old_embeddings = self.get_input_embeddings()
|
||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||
self.set_input_embeddings(new_embeddings)
|
||||
|
||||
# if word embeddings are not tied, make sure that lm head is resized as well
|
||||
if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
|
||||
old_lm_head = self.get_output_embeddings()
|
||||
new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
|
||||
self.set_output_embeddings(new_lm_head)
|
||||
|
||||
return self.get_input_embeddings()
|
||||
|
||||
def _get_resized_embeddings(
|
||||
@@ -653,9 +659,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
if old_num_tokens == new_num_tokens:
|
||||
return old_embeddings
|
||||
|
||||
if not isinstance(old_embeddings, nn.Embedding):
|
||||
raise TypeError(
|
||||
f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}."
|
||||
f"You should either use a different resize function or make sure that `old_embeddings` are an instance of {nn.Embedding}."
|
||||
)
|
||||
|
||||
# Build new embeddings
|
||||
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
|
||||
new_embeddings.to(old_embeddings.weight.device)
|
||||
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(self.device)
|
||||
|
||||
# initialize all new embeddings (in particular added tokens)
|
||||
self._init_weights(new_embeddings)
|
||||
@@ -666,6 +677,68 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
|
||||
return new_embeddings
|
||||
|
||||
def _get_resized_lm_head(
|
||||
self, old_lm_head: torch.nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
|
||||
) -> torch.nn.Linear:
|
||||
"""
|
||||
Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
|
||||
vectors at the end. Reducing the size will remove vectors from the end
|
||||
|
||||
Args:
|
||||
old_lm_head (:obj:`torch.nn.Linear`):
|
||||
Old lm head liner layer to be resized.
|
||||
new_num_tokens (:obj:`int`, `optional`):
|
||||
New number of tokens in the linear matrix.
|
||||
|
||||
Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
|
||||
vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
|
||||
:obj:`torch.nn.Linear`` module of the model without doing anything.
|
||||
transposed (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether ``old_lm_head`` is transposed or not. If True ``old_lm_head.size()`` is ``lm_head_dim,
|
||||
vocab_size`` else ``vocab_size, lm_head_dim``.
|
||||
|
||||
Return:
|
||||
:obj:`torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if
|
||||
:obj:`new_num_tokens` is :obj:`None`
|
||||
"""
|
||||
if new_num_tokens is None:
|
||||
return old_lm_head
|
||||
|
||||
old_num_tokens, old_lm_head_dim = (
|
||||
old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
|
||||
)
|
||||
|
||||
if old_num_tokens == new_num_tokens:
|
||||
return old_lm_head
|
||||
|
||||
if not isinstance(old_lm_head, nn.Linear):
|
||||
raise TypeError(
|
||||
f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}."
|
||||
f"You should either use a different resize function or make sure that `old_embeddings` are an instance of {nn.Linear}."
|
||||
)
|
||||
|
||||
# Build new lm head
|
||||
new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
|
||||
has_new_lm_head_bias = old_lm_head.bias is not None
|
||||
new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias).to(self.device)
|
||||
|
||||
# initialize new lm head (in particular added tokens)
|
||||
self._init_weights(new_lm_head)
|
||||
|
||||
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
||||
|
||||
# Copy old lm head weights to new lm head
|
||||
if not transposed:
|
||||
new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
|
||||
else:
|
||||
new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
|
||||
|
||||
# Copy bias weights to new lm head
|
||||
if has_new_lm_head_bias:
|
||||
new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
|
||||
|
||||
return new_lm_head
|
||||
|
||||
def init_weights(self):
|
||||
"""
|
||||
Initializes and prunes weights if needed.
|
||||
|
||||
@@ -632,12 +632,6 @@ class AlbertModel(AlbertPreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
old_embeddings = self.embeddings.word_embeddings
|
||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||
self.embeddings.word_embeddings = new_embeddings
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} ALBERT has
|
||||
@@ -748,6 +742,9 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.predictions.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.predictions.decoder = new_embeddings
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.albert.embeddings.word_embeddings
|
||||
|
||||
@@ -889,6 +886,9 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.predictions.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.predictions.decoder = new_embeddings
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.albert.embeddings.word_embeddings
|
||||
|
||||
|
||||
@@ -905,6 +905,9 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.cls.predictions.decoder = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@@ -1010,6 +1013,9 @@ class BertLMHeadModel(BertPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.cls.predictions.decoder = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@@ -1131,6 +1137,9 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.cls.predictions.decoder = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
|
||||
@@ -422,6 +422,9 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head.decoder = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
|
||||
@@ -496,6 +496,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past:
|
||||
|
||||
@@ -508,6 +508,9 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.vocab_projector
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.vocab_projector = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
|
||||
@@ -1003,6 +1003,9 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.generator_lm_head
|
||||
|
||||
def set_output_embeddings(self, word_embeddings):
|
||||
self.generator_lm_head = word_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
|
||||
@@ -194,6 +194,9 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.decoder.get_output_embeddings()
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
return self.decoder.set_output_embeddings(new_embeddings)
|
||||
|
||||
@classmethod
|
||||
def from_encoder_decoder_pretrained(
|
||||
cls,
|
||||
|
||||
@@ -1167,6 +1167,9 @@ class FunnelForMaskedLM(FunnelPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
|
||||
@@ -816,6 +816,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||
token_type_ids = kwargs.get("token_type_ids", None)
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
@@ -945,6 +948,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||
token_type_ids = kwargs.get("token_type_ids", None)
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
|
||||
@@ -781,6 +781,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.cls.predictions.decoder = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
|
||||
@@ -1632,6 +1632,9 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head.decoder = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=LongformerMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
|
||||
@@ -641,7 +641,7 @@ class MobileBertLMPredictionHead(nn.Module):
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
|
||||
hidden_states += self.bias
|
||||
hidden_states += self.decoder.bias
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -949,26 +949,16 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
Tie the weights between the input embeddings and the output embeddings. If the `torchscript` flag is set in the
|
||||
configuration, can't handle parameter sharing so we are cloning the weights instead.
|
||||
"""
|
||||
output_embeddings = self.get_output_embeddings()
|
||||
input_embeddings = self.get_input_embeddings()
|
||||
def set_output_embeddings(self, new_embeddigs):
|
||||
self.cls.predictions.decoder = new_embeddigs
|
||||
|
||||
resized_dense = nn.Linear(
|
||||
input_embeddings.num_embeddings, self.config.hidden_size - self.config.embedding_size, bias=False
|
||||
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
|
||||
# resize dense output embedings at first
|
||||
self.cls.predictions.dense = self._get_resized_lm_head(
|
||||
self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
|
||||
)
|
||||
kept_data = self.cls.predictions.dense.weight.data[
|
||||
..., : min(self.cls.predictions.dense.weight.data.shape[1], resized_dense.weight.data.shape[1])
|
||||
]
|
||||
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
|
||||
self.cls.predictions.dense = resized_dense
|
||||
self.cls.predictions.dense.to(self.device)
|
||||
|
||||
if output_embeddings is not None and self.config.tie_word_embeddings:
|
||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||
return super().resize_token_embeddings(new_num_tokens=new_num_tokens)
|
||||
|
||||
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@@ -1067,26 +1057,15 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
def tie_weights(self):
|
||||
"""
|
||||
Tie the weights between the input embeddings and the output embeddings. If the `torchscript` flag is set in the
|
||||
configuration, can't handle parameter sharing so we are cloning the weights instead.
|
||||
"""
|
||||
output_embeddings = self.get_output_embeddings()
|
||||
input_embeddings = self.get_input_embeddings()
|
||||
def set_output_embeddings(self, new_embeddigs):
|
||||
self.cls.predictions.decoder = new_embeddigs
|
||||
|
||||
resized_dense = nn.Linear(
|
||||
input_embeddings.num_embeddings, self.config.hidden_size - self.config.embedding_size, bias=False
|
||||
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
|
||||
# resize dense output embedings at first
|
||||
self.cls.predictions.dense = self._get_resized_lm_head(
|
||||
self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
|
||||
)
|
||||
kept_data = self.cls.predictions.dense.weight.data[
|
||||
..., : min(self.cls.predictions.dense.weight.data.shape[1], resized_dense.weight.data.shape[1])
|
||||
]
|
||||
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
|
||||
self.cls.predictions.dense = resized_dense
|
||||
self.cls.predictions.dense.to(self.device)
|
||||
|
||||
if output_embeddings is not None and self.config.tie_word_embeddings:
|
||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
||||
return super().resize_token_embeddings(new_num_tokens=new_num_tokens)
|
||||
|
||||
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
|
||||
@@ -542,6 +542,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
@@ -628,6 +631,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=OpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
|
||||
@@ -1703,6 +1703,9 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.prophetnet.word_embeddings
|
||||
|
||||
@@ -1901,6 +1904,9 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
|
||||
@@ -1459,6 +1459,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.rag.generator.get_output_embeddings()
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
return self.rag.generator.set_output_embeddings(new_embeddings)
|
||||
|
||||
def shift_tokens_right(self, input_ids, start_token_id=None):
|
||||
"""Shift input ids one token to the right, and pad with start_token_id"""
|
||||
if start_token_id is None:
|
||||
|
||||
@@ -2197,6 +2197,9 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head.decoder = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
@@ -2309,6 +2312,9 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head.decoder = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
|
||||
@@ -752,6 +752,9 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head.decoder = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@@ -873,6 +876,9 @@ class RobertaForMaskedLM(RobertaPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head.decoder = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
|
||||
@@ -655,6 +655,9 @@ class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.cls.predictions.decoder = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
|
||||
@@ -1363,6 +1363,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
self.encoder.set_input_embeddings(new_embeddings)
|
||||
self.decoder.set_input_embeddings(new_embeddings)
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
|
||||
@@ -688,6 +688,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.pred_layer.proj
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.pred_layer.proj = new_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||
mask_token_id = self.config.mask_token_id
|
||||
lang_id = self.config.lang_id
|
||||
|
||||
@@ -1312,6 +1312,9 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_loss
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_loss = new_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, use_mems=None, **kwargs):
|
||||
# Add dummy token at the end (no attention on this one)
|
||||
|
||||
|
||||
@@ -781,6 +781,9 @@ class {{cookiecutter.camelcase_modelname}}ForMaskedLM({{cookiecutter.camelcase_m
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.cls.predictions.decoder = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
|
||||
@@ -208,6 +208,10 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
|
||||
def test_resize_embeddings_untied(self):
|
||||
pass
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
def test_tiny_model(self):
|
||||
|
||||
@@ -128,6 +128,10 @@ class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase):
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
|
||||
def test_resize_embeddings_untied(self):
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipUnless(torch_device != "cpu", "3B test too slow on CPU.")
|
||||
@require_torch
|
||||
|
||||
@@ -815,6 +815,10 @@ class ModelTesterMixin:
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
# Input ids should be clamped to the maximum size of the vocabulary
|
||||
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||
|
||||
# make sure that decoder_input_ids are resized as well
|
||||
if "decoder_input_ids" in inputs_dict:
|
||||
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
|
||||
@@ -825,6 +829,57 @@ class ModelTesterMixin:
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def test_resize_embeddings_untied(self):
|
||||
(
|
||||
original_config,
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_resize_embeddings:
|
||||
return
|
||||
|
||||
original_config.tie_word_embeddings = False
|
||||
|
||||
# if model cannot untied embeddings -> leave test
|
||||
if original_config.tie_word_embeddings:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config = copy.deepcopy(original_config)
|
||||
model = model_class(config).to(torch_device)
|
||||
|
||||
# if no output embeddings -> leave test
|
||||
if model.get_output_embeddings() is None:
|
||||
continue
|
||||
|
||||
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||
model_vocab_size = config.vocab_size
|
||||
model.resize_token_embeddings(model_vocab_size + 10)
|
||||
self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
|
||||
output_embeds = model.get_output_embeddings()
|
||||
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
|
||||
# Check bias if present
|
||||
if output_embeds.bias is not None:
|
||||
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||
model.resize_token_embeddings(model_vocab_size - 15)
|
||||
self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
output_embeds = model.get_output_embeddings()
|
||||
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
|
||||
# Check bias if present
|
||||
if output_embeds.bias is not None:
|
||||
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
# Input ids should be clamped to the maximum size of the vocabulary
|
||||
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||
if "decoder_input_ids" in inputs_dict:
|
||||
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
@@ -226,15 +226,9 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
|
||||
# def test_auto_model(self):
|
||||
# # XXX: add a tiny model to s3?
|
||||
# model_name = "facebook/wmt19-ru-en-tiny"
|
||||
# tiny = AutoModel.from_pretrained(model_name) # same vocab size
|
||||
# tok = AutoTokenizer.from_pretrained(model_name) # same tokenizer
|
||||
# inputs_dict = tok.batch_encode_plus(["Hello my friends"], return_tensors="pt")
|
||||
|
||||
# with torch.no_grad():
|
||||
# tiny(**inputs_dict)
|
||||
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
|
||||
def test_resize_embeddings_untied(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
|
||||
@@ -574,6 +574,10 @@ class ReformerTesterMixin:
|
||||
# reformer cannot keep gradients in attentions or hidden states
|
||||
return
|
||||
|
||||
def test_resize_embeddings_untied(self):
|
||||
# reformer cannot resize embeddings that easily
|
||||
return
|
||||
|
||||
|
||||
@require_torch
|
||||
class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@@ -444,6 +444,20 @@ class T5ModelTester:
|
||||
)
|
||||
)
|
||||
|
||||
def check_resize_embeddings_t5_v1_1(
|
||||
self,
|
||||
config,
|
||||
):
|
||||
prev_vocab_size = config.vocab_size
|
||||
|
||||
config.tie_word_embeddings = False
|
||||
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
|
||||
model.resize_token_embeddings(prev_vocab_size - 10)
|
||||
|
||||
self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10)
|
||||
self.parent.assertEqual(model.get_output_embeddings().weight.shape[0], prev_vocab_size - 10)
|
||||
self.parent.assertEqual(model.config.vocab_size, prev_vocab_size - 10)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -480,7 +494,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
)
|
||||
test_pruning = False
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = False
|
||||
test_resize_embeddings = True
|
||||
test_model_parallel = True
|
||||
is_encoder_decoder = True
|
||||
|
||||
@@ -536,6 +550,10 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
|
||||
|
||||
def test_v1_1_resize_embeddings(self):
|
||||
config = self.model_tester.prepare_config_and_inputs()[0]
|
||||
self.model_tester.check_resize_embeddings_t5_v1_1(config)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
||||
@@ -299,6 +299,10 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
|
||||
self.assertEqual(model_vocab_size, model.config.vocab_size)
|
||||
self.assertEqual(model_embed.emb_layers[layer].weight.shape[0], cloned_embeddings[layer].shape[0])
|
||||
|
||||
def test_resize_embeddings_untied(self):
|
||||
# transfo-xl requires special resize for lm-head
|
||||
return
|
||||
|
||||
|
||||
@require_torch
|
||||
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user