[PyTorch] Refactor Resize Token Embeddings (#8880)
* fix resize tokens
* correct mobile_bert
* move embedding fix into modeling_utils.py
* refactor
* fix lm head resize
* refactor
* break lines to make sylvain happy
* add news tests
* fix typo
* improve test
* skip bart-like for now
* check if base_model = get(...) is necessary
* clean files
* improve test
* fix tests
* revert style templates
* Update templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
This commit is contained in:
committed by
GitHub
parent
e52f9c0ade
commit
443f67e887
@@ -422,6 +422,9 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head.decoder
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head.decoder = new_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user