[PyTorch] Refactor Resize Token Embeddings (#8880)

* fix resize tokens

* correct mobile_bert

* move embedding fix into modeling_utils.py

* refactor

* fix lm head resize

* refactor

* break lines to make sylvain happy

* add news tests

* fix typo

* improve test

* skip bart-like for now

* check if base_model = get(...) is necessary

* clean files

* improve test

* fix tests

* revert style templates

* Update templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
This commit is contained in:
Patrick von Platen
2020-12-02 19:19:50 +01:00
committed by GitHub
parent e52f9c0ade
commit 443f67e887
30 changed files with 273 additions and 57 deletions

View File

@@ -226,15 +226,9 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
def test_tie_model_weights(self):
pass
# def test_auto_model(self):
# # XXX: add a tiny model to s3?
# model_name = "facebook/wmt19-ru-en-tiny"
# tiny = AutoModel.from_pretrained(model_name) # same vocab size
# tok = AutoTokenizer.from_pretrained(model_name) # same tokenizer
# inputs_dict = tok.batch_encode_plus(["Hello my friends"], return_tensors="pt")
# with torch.no_grad():
# tiny(**inputs_dict)
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
def test_resize_embeddings_untied(self):
pass
@require_torch