[PyTorch] Refactor Resize Token Embeddings (#8880)
* fix resize tokens
* correct mobile_bert
* move embedding fix into modeling_utils.py
* refactor
* fix lm head resize
* refactor
* break lines to make sylvain happy
* add new tests
* fix typo
* improve test
* skip bart-like for now
* check if base_model = get(...) is necessary
* clean files
* improve test
* fix tests
* revert style templates
* Update templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
This commit is contained in:
committed by
GitHub
parent
e52f9c0ade
commit
443f67e887
@@ -605,14 +605,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
Return:
|
Return:
|
||||||
:obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
|
:obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
|
||||||
"""
|
"""
|
||||||
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
model_embeds = self._resize_token_embeddings(new_num_tokens)
|
||||||
model_embeds = base_model._resize_token_embeddings(new_num_tokens)
|
|
||||||
if new_num_tokens is None:
|
if new_num_tokens is None:
|
||||||
return model_embeds
|
return model_embeds
|
||||||
|
|
||||||
# Update base model and current model config
|
# Update base model and current model config
|
||||||
self.config.vocab_size = new_num_tokens
|
self.config.vocab_size = new_num_tokens
|
||||||
base_model.vocab_size = new_num_tokens
|
self.vocab_size = new_num_tokens
|
||||||
|
|
||||||
# Tie weights again if needed
|
# Tie weights again if needed
|
||||||
self.tie_weights()
|
self.tie_weights()
|
||||||
@@ -623,6 +622,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
old_embeddings = self.get_input_embeddings()
|
old_embeddings = self.get_input_embeddings()
|
||||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||||
self.set_input_embeddings(new_embeddings)
|
self.set_input_embeddings(new_embeddings)
|
||||||
|
|
||||||
|
# if word embeddings are not tied, make sure that lm head is resized as well
|
||||||
|
if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
|
||||||
|
old_lm_head = self.get_output_embeddings()
|
||||||
|
new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
|
||||||
|
self.set_output_embeddings(new_lm_head)
|
||||||
|
|
||||||
return self.get_input_embeddings()
|
return self.get_input_embeddings()
|
||||||
|
|
||||||
def _get_resized_embeddings(
|
def _get_resized_embeddings(
|
||||||
@@ -653,9 +659,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
if old_num_tokens == new_num_tokens:
|
if old_num_tokens == new_num_tokens:
|
||||||
return old_embeddings
|
return old_embeddings
|
||||||
|
|
||||||
|
if not isinstance(old_embeddings, nn.Embedding):
|
||||||
|
raise TypeError(
|
||||||
|
f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}."
|
||||||
|
f"You should either use a different resize function or make sure that `old_embeddings` are an instance of {nn.Embedding}."
|
||||||
|
)
|
||||||
|
|
||||||
# Build new embeddings
|
# Build new embeddings
|
||||||
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
|
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim).to(self.device)
|
||||||
new_embeddings.to(old_embeddings.weight.device)
|
|
||||||
|
|
||||||
# initialize all new embeddings (in particular added tokens)
|
# initialize all new embeddings (in particular added tokens)
|
||||||
self._init_weights(new_embeddings)
|
self._init_weights(new_embeddings)
|
||||||
@@ -666,6 +677,68 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
|
|
||||||
return new_embeddings
|
return new_embeddings
|
||||||
|
|
||||||
|
def _get_resized_lm_head(
|
||||||
|
self, old_lm_head: torch.nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
|
||||||
|
) -> torch.nn.Linear:
|
||||||
|
"""
|
||||||
|
Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
|
||||||
|
vectors at the end. Reducing the size will remove vectors from the end.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
old_lm_head (:obj:`torch.nn.Linear`):
|
||||||
|
Old lm head linear layer to be resized.
|
||||||
|
new_num_tokens (:obj:`int`, `optional`):
|
||||||
|
New number of tokens in the linear matrix.
|
||||||
|
|
||||||
|
Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
|
||||||
|
vectors from the end. If not provided or :obj:`None`, just returns a pointer to the old lm head
|
||||||
|
:obj:`torch.nn.Linear` module of the model without doing anything.
|
||||||
|
transposed (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether ``old_lm_head`` is transposed or not. If True ``old_lm_head.size()`` is ``lm_head_dim,
|
||||||
|
vocab_size`` else ``vocab_size, lm_head_dim``.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
:obj:`torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if
|
||||||
|
:obj:`new_num_tokens` is :obj:`None`
|
||||||
|
"""
|
||||||
|
if new_num_tokens is None:
|
||||||
|
return old_lm_head
|
||||||
|
|
||||||
|
old_num_tokens, old_lm_head_dim = (
|
||||||
|
old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
|
||||||
|
)
|
||||||
|
|
||||||
|
if old_num_tokens == new_num_tokens:
|
||||||
|
return old_lm_head
|
||||||
|
|
||||||
|
if not isinstance(old_lm_head, nn.Linear):
|
||||||
|
raise TypeError(
|
||||||
|
f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}."
|
||||||
|
f"You should either use a different resize function or make sure that `old_embeddings` are an instance of {nn.Linear}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build new lm head
|
||||||
|
new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
|
||||||
|
has_new_lm_head_bias = old_lm_head.bias is not None
|
||||||
|
new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias).to(self.device)
|
||||||
|
|
||||||
|
# initialize new lm head (in particular added tokens)
|
||||||
|
self._init_weights(new_lm_head)
|
||||||
|
|
||||||
|
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
||||||
|
|
||||||
|
# Copy old lm head weights to new lm head
|
||||||
|
if not transposed:
|
||||||
|
new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
|
||||||
|
else:
|
||||||
|
new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
|
||||||
|
|
||||||
|
# Copy bias weights to new lm head
|
||||||
|
if has_new_lm_head_bias:
|
||||||
|
new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
|
||||||
|
|
||||||
|
return new_lm_head
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
"""
|
"""
|
||||||
Initializes and prunes weights if needed.
|
Initializes and prunes weights if needed.
|
||||||
|
|||||||
@@ -632,12 +632,6 @@ class AlbertModel(AlbertPreTrainedModel):
|
|||||||
def set_input_embeddings(self, value):
|
def set_input_embeddings(self, value):
|
||||||
self.embeddings.word_embeddings = value
|
self.embeddings.word_embeddings = value
|
||||||
|
|
||||||
def _resize_token_embeddings(self, new_num_tokens):
|
|
||||||
old_embeddings = self.embeddings.word_embeddings
|
|
||||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
|
||||||
self.embeddings.word_embeddings = new_embeddings
|
|
||||||
return self.embeddings.word_embeddings
|
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
"""
|
"""
|
||||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} ALBERT has
|
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} ALBERT has
|
||||||
@@ -748,6 +742,9 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.predictions.decoder
|
return self.predictions.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.predictions.decoder = new_embeddings
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.albert.embeddings.word_embeddings
|
return self.albert.embeddings.word_embeddings
|
||||||
|
|
||||||
@@ -889,6 +886,9 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.predictions.decoder
|
return self.predictions.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.predictions.decoder = new_embeddings
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.albert.embeddings.word_embeddings
|
return self.albert.embeddings.word_embeddings
|
||||||
|
|
||||||
|
|||||||
@@ -905,6 +905,9 @@ class BertForPreTraining(BertPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.cls.predictions.decoder
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.cls.predictions.decoder = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -1010,6 +1013,9 @@ class BertLMHeadModel(BertPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.cls.predictions.decoder
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.cls.predictions.decoder = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -1131,6 +1137,9 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.cls.predictions.decoder
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.cls.predictions.decoder = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
|||||||
@@ -422,6 +422,9 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head.decoder
|
return self.lm_head.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head.decoder = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -496,6 +496,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
|
||||||
# only last token for inputs_ids if past is defined in kwargs
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
if past:
|
if past:
|
||||||
|
|||||||
@@ -508,6 +508,9 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.vocab_projector
|
return self.vocab_projector
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.vocab_projector = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices"))
|
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
|||||||
@@ -1003,6 +1003,9 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.generator_lm_head
|
return self.generator_lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, word_embeddings):
|
||||||
|
self.generator_lm_head = word_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
|||||||
@@ -194,6 +194,9 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.decoder.get_output_embeddings()
|
return self.decoder.get_output_embeddings()
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
return self.decoder.set_output_embeddings(new_embeddings)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_encoder_decoder_pretrained(
|
def from_encoder_decoder_pretrained(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@@ -1167,6 +1167,9 @@ class FunnelForMaskedLM(FunnelPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
|||||||
@@ -816,6 +816,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||||
token_type_ids = kwargs.get("token_type_ids", None)
|
token_type_ids = kwargs.get("token_type_ids", None)
|
||||||
# only last token for inputs_ids if past is defined in kwargs
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
@@ -945,6 +948,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||||
token_type_ids = kwargs.get("token_type_ids", None)
|
token_type_ids = kwargs.get("token_type_ids", None)
|
||||||
# only last token for inputs_ids if past is defined in kwargs
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
|
|||||||
@@ -781,6 +781,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.cls.predictions.decoder
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.cls.predictions.decoder = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
|||||||
@@ -1632,6 +1632,9 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head.decoder
|
return self.lm_head.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head.decoder = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@replace_return_docstrings(output_type=LongformerMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=LongformerMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -641,7 +641,7 @@ class MobileBertLMPredictionHead(nn.Module):
|
|||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
hidden_states = self.transform(hidden_states)
|
hidden_states = self.transform(hidden_states)
|
||||||
hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
|
hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
|
||||||
hidden_states += self.bias
|
hidden_states += self.decoder.bias
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -949,26 +949,16 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.cls.predictions.decoder
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
def tie_weights(self):
|
def set_output_embeddings(self, new_embeddigs):
|
||||||
"""
|
self.cls.predictions.decoder = new_embeddigs
|
||||||
Tie the weights between the input embeddings and the output embeddings. If the `torchscript` flag is set in the
|
|
||||||
configuration, can't handle parameter sharing so we are cloning the weights instead.
|
|
||||||
"""
|
|
||||||
output_embeddings = self.get_output_embeddings()
|
|
||||||
input_embeddings = self.get_input_embeddings()
|
|
||||||
|
|
||||||
resized_dense = nn.Linear(
|
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
|
||||||
input_embeddings.num_embeddings, self.config.hidden_size - self.config.embedding_size, bias=False
|
# resize dense output embeddings first
|
||||||
|
self.cls.predictions.dense = self._get_resized_lm_head(
|
||||||
|
self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
|
||||||
)
|
)
|
||||||
kept_data = self.cls.predictions.dense.weight.data[
|
|
||||||
..., : min(self.cls.predictions.dense.weight.data.shape[1], resized_dense.weight.data.shape[1])
|
|
||||||
]
|
|
||||||
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
|
|
||||||
self.cls.predictions.dense = resized_dense
|
|
||||||
self.cls.predictions.dense.to(self.device)
|
|
||||||
|
|
||||||
if output_embeddings is not None and self.config.tie_word_embeddings:
|
return super().resize_token_embeddings(new_num_tokens=new_num_tokens)
|
||||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=MobileBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
@@ -1067,26 +1057,15 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.cls.predictions.decoder
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
def tie_weights(self):
|
def set_output_embeddings(self, new_embeddigs):
|
||||||
"""
|
self.cls.predictions.decoder = new_embeddigs
|
||||||
Tie the weights between the input embeddings and the output embeddings. If the `torchscript` flag is set in the
|
|
||||||
configuration, can't handle parameter sharing so we are cloning the weights instead.
|
|
||||||
"""
|
|
||||||
output_embeddings = self.get_output_embeddings()
|
|
||||||
input_embeddings = self.get_input_embeddings()
|
|
||||||
|
|
||||||
resized_dense = nn.Linear(
|
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
|
||||||
input_embeddings.num_embeddings, self.config.hidden_size - self.config.embedding_size, bias=False
|
# resize dense output embeddings first
|
||||||
|
self.cls.predictions.dense = self._get_resized_lm_head(
|
||||||
|
self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
|
||||||
)
|
)
|
||||||
kept_data = self.cls.predictions.dense.weight.data[
|
return super().resize_token_embeddings(new_num_tokens=new_num_tokens)
|
||||||
..., : min(self.cls.predictions.dense.weight.data.shape[1], resized_dense.weight.data.shape[1])
|
|
||||||
]
|
|
||||||
resized_dense.weight.data[..., : self.cls.predictions.dense.weight.data.shape[1]] = kept_data
|
|
||||||
self.cls.predictions.dense = resized_dense
|
|
||||||
self.cls.predictions.dense.to(self.device)
|
|
||||||
|
|
||||||
if output_embeddings is not None and self.config.tie_word_embeddings:
|
|
||||||
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
|
|||||||
@@ -542,6 +542,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
@@ -628,6 +631,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=OpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=OpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -1703,6 +1703,9 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.prophetnet.word_embeddings
|
return self.prophetnet.word_embeddings
|
||||||
|
|
||||||
@@ -1901,6 +1904,9 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -1459,6 +1459,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.rag.generator.get_output_embeddings()
|
return self.rag.generator.get_output_embeddings()
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
return self.rag.generator.set_output_embeddings(new_embeddings)
|
||||||
|
|
||||||
def shift_tokens_right(self, input_ids, start_token_id=None):
|
def shift_tokens_right(self, input_ids, start_token_id=None):
|
||||||
"""Shift input ids one token to the right, and pad with start_token_id"""
|
"""Shift input ids one token to the right, and pad with start_token_id"""
|
||||||
if start_token_id is None:
|
if start_token_id is None:
|
||||||
|
|||||||
@@ -2197,6 +2197,9 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head.decoder
|
return self.lm_head.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head.decoder = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
@@ -2309,6 +2312,9 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head.decoder
|
return self.lm_head.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head.decoder = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
|||||||
@@ -752,6 +752,9 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head.decoder
|
return self.lm_head.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head.decoder = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -873,6 +876,9 @@ class RobertaForMaskedLM(RobertaPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head.decoder
|
return self.lm_head.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head.decoder = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
|||||||
@@ -655,6 +655,9 @@ class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.cls.predictions.decoder
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.cls.predictions.decoder = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
@add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
|||||||
@@ -1363,6 +1363,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||||||
self.encoder.set_input_embeddings(new_embeddings)
|
self.encoder.set_input_embeddings(new_embeddings)
|
||||||
self.decoder.set_input_embeddings(new_embeddings)
|
self.decoder.set_input_embeddings(new_embeddings)
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
|
|||||||
@@ -688,6 +688,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.pred_layer.proj
|
return self.pred_layer.proj
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.pred_layer.proj = new_embeddings
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||||
mask_token_id = self.config.mask_token_id
|
mask_token_id = self.config.mask_token_id
|
||||||
lang_id = self.config.lang_id
|
lang_id = self.config.lang_id
|
||||||
|
|||||||
@@ -1312,6 +1312,9 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_loss
|
return self.lm_loss
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.lm_loss = new_embeddings
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past=None, use_mems=None, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past=None, use_mems=None, **kwargs):
|
||||||
# Add dummy token at the end (no attention on this one)
|
# Add dummy token at the end (no attention on this one)
|
||||||
|
|
||||||
|
|||||||
@@ -781,6 +781,9 @@ class {{cookiecutter.camelcase_modelname}}ForMaskedLM({{cookiecutter.camelcase_m
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.cls.predictions.decoder
|
return self.cls.predictions.decoder
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
|
||||||
|
self.cls.predictions.decoder = new_embeddings
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
|
|||||||
@@ -208,6 +208,10 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
|
||||||
|
def test_resize_embeddings_untied(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
def test_tiny_model(self):
|
def test_tiny_model(self):
|
||||||
|
|||||||
@@ -128,6 +128,10 @@ class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase):
|
|||||||
def test_feed_forward_chunking(self):
|
def test_feed_forward_chunking(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
|
||||||
|
def test_resize_embeddings_untied(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipUnless(torch_device != "cpu", "3B test too slow on CPU.")
|
@unittest.skipUnless(torch_device != "cpu", "3B test too slow on CPU.")
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
@@ -815,6 +815,10 @@ class ModelTesterMixin:
|
|||||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
# Input ids should be clamped to the maximum size of the vocabulary
|
# Input ids should be clamped to the maximum size of the vocabulary
|
||||||
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||||
|
|
||||||
|
# make sure that decoder_input_ids are clamped to the reduced vocabulary size as well
|
||||||
|
if "decoder_input_ids" in inputs_dict:
|
||||||
|
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
|
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
|
||||||
@@ -825,6 +829,57 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
self.assertTrue(models_equal)
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
def test_resize_embeddings_untied(self):
|
||||||
|
(
|
||||||
|
original_config,
|
||||||
|
inputs_dict,
|
||||||
|
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
if not self.test_resize_embeddings:
|
||||||
|
return
|
||||||
|
|
||||||
|
original_config.tie_word_embeddings = False
|
||||||
|
|
||||||
|
# if the model cannot untie its embeddings -> leave test
|
||||||
|
if original_config.tie_word_embeddings:
|
||||||
|
return
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
config = copy.deepcopy(original_config)
|
||||||
|
model = model_class(config).to(torch_device)
|
||||||
|
|
||||||
|
# if no output embeddings -> leave test
|
||||||
|
if model.get_output_embeddings() is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||||
|
model_vocab_size = config.vocab_size
|
||||||
|
model.resize_token_embeddings(model_vocab_size + 10)
|
||||||
|
self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
|
||||||
|
output_embeds = model.get_output_embeddings()
|
||||||
|
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
|
||||||
|
# Check bias if present
|
||||||
|
if output_embeds.bias is not None:
|
||||||
|
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
|
||||||
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
|
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||||
|
model.resize_token_embeddings(model_vocab_size - 15)
|
||||||
|
self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
|
||||||
|
# Check that it actually resizes the embeddings matrix
|
||||||
|
output_embeds = model.get_output_embeddings()
|
||||||
|
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
|
||||||
|
# Check bias if present
|
||||||
|
if output_embeds.bias is not None:
|
||||||
|
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
|
||||||
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
|
# Input ids should be clamped to the maximum size of the vocabulary
|
||||||
|
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||||
|
if "decoder_input_ids" in inputs_dict:
|
||||||
|
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||||
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
|
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
def test_model_common_attributes(self):
|
def test_model_common_attributes(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
@@ -226,15 +226,9 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
def test_tie_model_weights(self):
|
def test_tie_model_weights(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# def test_auto_model(self):
|
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
|
||||||
# # XXX: add a tiny model to s3?
|
def test_resize_embeddings_untied(self):
|
||||||
# model_name = "facebook/wmt19-ru-en-tiny"
|
pass
|
||||||
# tiny = AutoModel.from_pretrained(model_name) # same vocab size
|
|
||||||
# tok = AutoTokenizer.from_pretrained(model_name) # same tokenizer
|
|
||||||
# inputs_dict = tok.batch_encode_plus(["Hello my friends"], return_tensors="pt")
|
|
||||||
|
|
||||||
# with torch.no_grad():
|
|
||||||
# tiny(**inputs_dict)
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
@@ -574,6 +574,10 @@ class ReformerTesterMixin:
|
|||||||
# reformer cannot keep gradients in attentions or hidden states
|
# reformer cannot keep gradients in attentions or hidden states
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def test_resize_embeddings_untied(self):
|
||||||
|
# reformer cannot resize embeddings that easily
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, ModelTesterMixin, unittest.TestCase):
|
class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, ModelTesterMixin, unittest.TestCase):
|
||||||
|
|||||||
@@ -444,6 +444,20 @@ class T5ModelTester:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_resize_embeddings_t5_v1_1(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
):
|
||||||
|
prev_vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
config.tie_word_embeddings = False
|
||||||
|
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
|
||||||
|
model.resize_token_embeddings(prev_vocab_size - 10)
|
||||||
|
|
||||||
|
self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10)
|
||||||
|
self.parent.assertEqual(model.get_output_embeddings().weight.shape[0], prev_vocab_size - 10)
|
||||||
|
self.parent.assertEqual(model.config.vocab_size, prev_vocab_size - 10)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(
|
(
|
||||||
@@ -480,7 +494,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_torchscript = True
|
test_torchscript = True
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = True
|
||||||
test_model_parallel = True
|
test_model_parallel = True
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
|
|
||||||
@@ -536,6 +550,10 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
|
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_v1_1_resize_embeddings(self):
|
||||||
|
config = self.model_tester.prepare_config_and_inputs()[0]
|
||||||
|
self.model_tester.check_resize_embeddings_t5_v1_1(config)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
@@ -299,6 +299,10 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
|
|||||||
self.assertEqual(model_vocab_size, model.config.vocab_size)
|
self.assertEqual(model_vocab_size, model.config.vocab_size)
|
||||||
self.assertEqual(model_embed.emb_layers[layer].weight.shape[0], cloned_embeddings[layer].shape[0])
|
self.assertEqual(model_embed.emb_layers[layer].weight.shape[0], cloned_embeddings[layer].shape[0])
|
||||||
|
|
||||||
|
def test_resize_embeddings_untied(self):
|
||||||
|
# transfo-xl requires special resize for lm-head
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user