[PyTorch] Refactor Resize Token Embeddings (#8880)
* fix resize tokens
* correct mobile_bert
* move embedding fix into modeling_utils.py
* refactor
* fix lm head resize
* refactor
* break lines to make sylvain happy
* add news tests
* fix typo
* improve test
* skip bart-like for now
* check if base_model = get(...) is necessary
* clean files
* improve test
* fix tests
* revert style templates
* Update templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
This commit is contained in:
committed by
GitHub
parent
e52f9c0ade
commit
443f67e887
@@ -226,15 +226,9 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
|
||||
# def test_auto_model(self):
|
||||
# # XXX: add a tiny model to s3?
|
||||
# model_name = "facebook/wmt19-ru-en-tiny"
|
||||
# tiny = AutoModel.from_pretrained(model_name) # same vocab size
|
||||
# tok = AutoTokenizer.from_pretrained(model_name) # same tokenizer
|
||||
# inputs_dict = tok.batch_encode_plus(["Hello my friends"], return_tensors="pt")
|
||||
|
||||
# with torch.no_grad():
|
||||
# tiny(**inputs_dict)
|
||||
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
|
||||
def test_resize_embeddings_untied(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user