[PyTorch] Refactor Resize Token Embeddings (#8880)

* fix resize tokens

* correct mobile_bert

* move embedding fix into modeling_utils.py

* refactor

* fix lm head resize

* refactor

* break lines to make sylvain happy

* add new tests

* fix typo

* improve test

* skip bart-like for now

* check if base_model = get(...) is necessary

* clean files

* improve test

* fix tests

* revert style templates

* Update templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
This commit is contained in:
Patrick von Platen
2020-12-02 19:19:50 +01:00
committed by GitHub
parent e52f9c0ade
commit 443f67e887
30 changed files with 273 additions and 57 deletions

View File

@@ -605,14 +605,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
Return:
:obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
"""
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
model_embeds = base_model._resize_token_embeddings(new_num_tokens)
model_embeds = self._resize_token_embeddings(new_num_tokens)
if new_num_tokens is None:
return model_embeds
# Update base model and current model config
self.config.vocab_size = new_num_tokens
base_model.vocab_size = new_num_tokens
self.vocab_size = new_num_tokens
# Tie weights again if needed
self.tie_weights()
@@ -623,6 +622,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
old_embeddings = self.get_input_embeddings()
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.set_input_embeddings(new_embeddings)
# if word embeddings are not tied, make sure that lm head is resized as well
if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
old_lm_head = self.get_output_embeddings()
new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
self.set_output_embeddings(new_lm_head)
return self.get_input_embeddings()
def _get_resized_embeddings(
@@ -653,9 +659,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
if old_num_tokens == new_num_tokens:
return old_embeddings
if not isinstance(old_embeddings, nn.Embedding):
raise TypeError(
f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}."
f"You should either use a different resize function or make sure that `old_embeddings` are an instance of {nn.Embedding}."
)
# Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
new_embeddings.to(old_embeddings.weight.device)
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(self.device)
# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)
@@ -666,6 +677,68 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
return new_embeddings
def _get_resized_lm_head(
    self, old_lm_head: torch.nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
) -> torch.nn.Linear:
    """
    Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
    vectors at the end. Reducing the size will remove vectors from the end.

    Args:
        old_lm_head (:obj:`torch.nn.Linear`):
            Old lm head linear layer to be resized.
        new_num_tokens (:obj:`int`, `optional`):
            New number of tokens in the linear matrix.

            Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
            vectors from the end. If not provided or :obj:`None`, just returns ``old_lm_head`` without doing
            anything.
        transposed (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether ``old_lm_head`` is transposed or not. If True ``old_lm_head.size()`` is ``lm_head_dim,
            vocab_size`` else ``vocab_size, lm_head_dim``.

    Return:
        :obj:`torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if
        :obj:`new_num_tokens` is :obj:`None` or already equal to the current number of tokens.

    Raises:
        :obj:`TypeError`: If a resize is actually required and ``old_lm_head`` is not a :obj:`torch.nn.Linear`.
    """
    if new_num_tokens is None:
        return old_lm_head

    old_num_tokens, old_lm_head_dim = (
        old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
    )

    if old_num_tokens == new_num_tokens:
        return old_lm_head

    if not isinstance(old_lm_head, nn.Linear):
        raise TypeError(
            f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. "
            f"You should either use a different resize function or make sure that `old_lm_head` is an instance of {nn.Linear}."
        )

    # Build new lm head; nn.Linear takes (in_features, out_features) and stores its weight as (out, in)
    new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
    has_new_lm_head_bias = old_lm_head.bias is not None
    new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias).to(self.device)

    # initialize new lm head (in particular added tokens)
    self._init_weights(new_lm_head)

    num_tokens_to_copy = min(old_num_tokens, new_num_tokens)

    if not transposed:
        # The token axis is the first weight axis, and the bias has one entry per token
        new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
        if has_new_lm_head_bias:
            new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
    else:
        # The token axis is the second weight axis; the bias has `lm_head_dim` entries, which resizing the
        # vocabulary does not change, so it is carried over in full rather than sliced by token count
        new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
        if has_new_lm_head_bias:
            new_lm_head.bias.data.copy_(old_lm_head.bias.data)

    return new_lm_head
def init_weights(self):
"""
Initializes and prunes weights if needed.

View File

@@ -632,12 +632,6 @@ class AlbertModel(AlbertPreTrainedModel):
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.embeddings.word_embeddings
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.embeddings.word_embeddings = new_embeddings
return self.embeddings.word_embeddings
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} ALBERT has
@@ -748,6 +742,9 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
def get_output_embeddings(self):
return self.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.predictions.decoder = new_embeddings
def get_input_embeddings(self):
return self.albert.embeddings.word_embeddings
@@ -889,6 +886,9 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
def get_output_embeddings(self):
return self.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.predictions.decoder = new_embeddings
def get_input_embeddings(self):
return self.albert.embeddings.word_embeddings

View File

@@ -905,6 +905,9 @@ class BertForPreTraining(BertPreTrainedModel):
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -1010,6 +1013,9 @@ class BertLMHeadModel(BertPreTrainedModel):
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -1131,6 +1137,9 @@ class BertForMaskedLM(BertPreTrainedModel):
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,

View File

@@ -422,6 +422,9 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head.decoder
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(

View File

@@ -496,6 +496,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
if past:

View File

@@ -508,6 +508,9 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
def get_output_embeddings(self):
return self.vocab_projector
def set_output_embeddings(self, new_embeddings):
self.vocab_projector = new_embeddings
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,

View File

@@ -1003,6 +1003,9 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
def get_output_embeddings(self):
return self.generator_lm_head
def set_output_embeddings(self, word_embeddings):
self.generator_lm_head = word_embeddings
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,

View File

@@ -194,6 +194,9 @@ class EncoderDecoderModel(PreTrainedModel):
def get_output_embeddings(self):
return self.decoder.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
return self.decoder.set_output_embeddings(new_embeddings)
@classmethod
def from_encoder_decoder_pretrained(
cls,

View File

@@ -1167,6 +1167,9 @@ class FunnelForMaskedLM(FunnelPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,

View File

@@ -816,6 +816,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
@@ -945,6 +948,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs

View File

@@ -781,6 +781,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,

View File

@@ -1632,6 +1632,9 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head.decoder
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=LongformerMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(

View File

@@ -641,7 +641,7 @@ class MobileBertLMPredictionHead(nn.Module):
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
hidden_states += self.bias
hidden_states += self.decoder.bias
return hidden_states
@@ -949,26 +949,16 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
def get_output_embeddings(self):
return self.cls.predictions.decoder
def tie_weights(self):
"""
Tie the weights between the input embeddings and the output embeddings. If the `torchscript` flag is set in the
configuration, can't handle parameter sharing so we are cloning the weights instead.
"""
output_embeddings = self.get_output_embeddings()
input_embeddings = self.get_input_embeddings()
def set_output_embeddings(self, new_embeddigs):
self.cls.predictions.decoder = new_embeddigs
resized_dense = nn.Linear(
input_embeddings.num_embeddings, self.config.hidden_size - self.config.embedding_size, bias=False
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
# resize dense output embeddings first
self.cls.predictions.dense = self._get_resized_lm_head(
self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
)
kept_data = self.cls.predictions.dense.weight.data[
..., : min(self.cls.predictions.dense.weight.data.shape[1], resized_dense.weight.data.shape[1])
]
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
self.cls.predictions.dense = resized_dense
self.cls.predictions.dense.to(self.device)
if output_embeddings is not None and self.config.tie_word_embeddings:
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
return super().resize_token_embeddings(new_num_tokens=new_num_tokens)
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1067,26 +1057,15 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
def get_output_embeddings(self):
return self.cls.predictions.decoder
def tie_weights(self):
"""
Tie the weights between the input embeddings and the output embeddings. If the `torchscript` flag is set in the
configuration, can't handle parameter sharing so we are cloning the weights instead.
"""
output_embeddings = self.get_output_embeddings()
input_embeddings = self.get_input_embeddings()
def set_output_embeddings(self, new_embeddigs):
self.cls.predictions.decoder = new_embeddigs
resized_dense = nn.Linear(
input_embeddings.num_embeddings, self.config.hidden_size - self.config.embedding_size, bias=False
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
# resize dense output embeddings first
self.cls.predictions.dense = self._get_resized_lm_head(
self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
)
kept_data = self.cls.predictions.dense.weight.data[
..., : min(self.cls.predictions.dense.weight.data.shape[1], resized_dense.weight.data.shape[1])
]
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
self.cls.predictions.dense = resized_dense
self.cls.predictions.dense.to(self.device)
if output_embeddings is not None and self.config.tie_word_embeddings:
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
return super().resize_token_embeddings(new_num_tokens=new_num_tokens)
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(

View File

@@ -542,6 +542,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
@@ -628,6 +631,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=OpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(

View File

@@ -1703,6 +1703,9 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_input_embeddings(self):
return self.prophetnet.word_embeddings
@@ -1901,6 +1904,9 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
@add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(

View File

@@ -1459,6 +1459,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
def get_output_embeddings(self):
return self.rag.generator.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
return self.rag.generator.set_output_embeddings(new_embeddings)
def shift_tokens_right(self, input_ids, start_token_id=None):
"""Shift input ids one token to the right, and pad with start_token_id"""
if start_token_id is None:

View File

@@ -2197,6 +2197,9 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head.decoder
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
@add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
@@ -2309,6 +2312,9 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head.decoder
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
@add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,

View File

@@ -752,6 +752,9 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head.decoder
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -873,6 +876,9 @@ class RobertaForMaskedLM(RobertaPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head.decoder
def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,

View File

@@ -655,6 +655,9 @@ class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,

View File

@@ -1363,6 +1363,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_output_embeddings(self):
return self.lm_head

View File

@@ -688,6 +688,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
def get_output_embeddings(self):
return self.pred_layer.proj
def set_output_embeddings(self, new_embeddings):
self.pred_layer.proj = new_embeddings
def prepare_inputs_for_generation(self, input_ids, **kwargs):
mask_token_id = self.config.mask_token_id
lang_id = self.config.lang_id

View File

@@ -1312,6 +1312,9 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def get_output_embeddings(self):
return self.lm_loss
def set_output_embeddings(self, new_embeddings):
self.lm_loss = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, use_mems=None, **kwargs):
# Add dummy token at the end (no attention on this one)

View File

@@ -781,6 +781,9 @@ class {{cookiecutter.camelcase_modelname}}ForMaskedLM({{cookiecutter.camelcase_m
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,

View File

@@ -208,6 +208,10 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
pass
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
def test_resize_embeddings_untied(self):
pass
@require_sentencepiece
@require_tokenizers
def test_tiny_model(self):

View File

@@ -128,6 +128,10 @@ class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase):
def test_feed_forward_chunking(self):
pass
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
def test_resize_embeddings_untied(self):
pass
@unittest.skipUnless(torch_device != "cpu", "3B test too slow on CPU.")
@require_torch

View File

@@ -815,6 +815,10 @@ class ModelTesterMixin:
# Check that the model can still do a forward pass successfully (every parameter should be resized)
# Input ids should be clamped to the maximum size of the vocabulary
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
# make sure that decoder_input_ids are resized as well
if "decoder_input_ids" in inputs_dict:
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
@@ -825,6 +829,57 @@ class ModelTesterMixin:
self.assertTrue(models_equal)
def test_resize_embeddings_untied(self):
    """
    Check that resizing the token embeddings also resizes the output embeddings (weight and, if present, bias)
    when ``config.tie_word_embeddings`` is ``False``, and that a forward pass still runs after growing and
    after shrinking the vocabulary.
    """
    original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    if not self.test_resize_embeddings:
        return

    original_config.tie_word_embeddings = False

    # if the flag is still truthy after being set, the embeddings cannot be untied -> leave test
    if original_config.tie_word_embeddings:
        return

    for model_class in self.all_model_classes:
        config = copy.deepcopy(original_config)
        model = model_class(config).to(torch_device)

        # models without output embeddings have nothing to check
        if model.get_output_embeddings() is None:
            continue

        base_vocab_size = config.vocab_size

        # first grow the vocabulary by 10 tokens, then shrink it 15 below the starting size
        for delta in (10, -15):
            target_vocab_size = base_vocab_size + delta
            model.resize_token_embeddings(target_vocab_size)
            self.assertEqual(model.config.vocab_size, target_vocab_size)

            # the output embedding matrix itself must have been resized
            lm_head = model.get_output_embeddings()
            self.assertEqual(lm_head.weight.shape[0], target_vocab_size)
            # Check bias if present
            if lm_head.bias is not None:
                self.assertEqual(lm_head.bias.shape[0], target_vocab_size)

            if delta < 0:
                # after shrinking, token ids must be clamped (in place) to the largest valid id
                largest_valid_id = target_vocab_size - 1
                inputs_dict["input_ids"].clamp_(max=largest_valid_id)
                if "decoder_input_ids" in inputs_dict:
                    inputs_dict["decoder_input_ids"].clamp_(max=largest_valid_id)

            # the model must still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -226,15 +226,9 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
def test_tie_model_weights(self):
pass
# def test_auto_model(self):
# # XXX: add a tiny model to s3?
# model_name = "facebook/wmt19-ru-en-tiny"
# tiny = AutoModel.from_pretrained(model_name) # same vocab size
# tok = AutoTokenizer.from_pretrained(model_name) # same tokenizer
# inputs_dict = tok.batch_encode_plus(["Hello my friends"], return_tensors="pt")
# with torch.no_grad():
# tiny(**inputs_dict)
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
def test_resize_embeddings_untied(self):
pass
@require_torch

View File

@@ -574,6 +574,10 @@ class ReformerTesterMixin:
# reformer cannot keep gradients in attentions or hidden states
return
def test_resize_embeddings_untied(self):
# reformer cannot resize embeddings that easily
return
@require_torch
class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, ModelTesterMixin, unittest.TestCase):

View File

@@ -444,6 +444,20 @@ class T5ModelTester:
)
)
def check_resize_embeddings_t5_v1_1(self, config):
    """
    Build a ``T5ForConditionalGeneration`` with untied word embeddings, shrink its vocabulary by 10 tokens and
    check that the input embeddings, the output embeddings and ``config.vocab_size`` all report the new size.
    """
    shrunk_vocab_size = config.vocab_size - 10
    config.tie_word_embeddings = False

    model = T5ForConditionalGeneration(config=config)
    model = model.to(torch_device).eval()
    model.resize_token_embeddings(shrunk_vocab_size)

    check_equal = self.parent.assertEqual
    check_equal(model.get_input_embeddings().weight.shape[0], shrunk_vocab_size)
    check_equal(model.get_output_embeddings().weight.shape[0], shrunk_vocab_size)
    check_equal(model.config.vocab_size, shrunk_vocab_size)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -480,7 +494,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
)
test_pruning = False
test_torchscript = True
test_resize_embeddings = False
test_resize_embeddings = True
test_model_parallel = True
is_encoder_decoder = True
@@ -536,6 +550,10 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
def test_v1_1_resize_embeddings(self):
config = self.model_tester.prepare_config_and_inputs()[0]
self.model_tester.check_resize_embeddings_t5_v1_1(config)
@slow
def test_model_from_pretrained(self):
for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:

View File

@@ -299,6 +299,10 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
self.assertEqual(model_vocab_size, model.config.vocab_size)
self.assertEqual(model_embed.emb_layers[layer].weight.shape[0], cloned_embeddings[layer].shape[0])
def test_resize_embeddings_untied(self):
# transfo-xl requires special resize for lm-head
return
@require_torch
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):