[PyTorch] Refactor Resize Token Embeddings (#8880)

* fix resize tokens

* correct mobile_bert

* move embedding fix into modeling_utils.py

* refactor

* fix lm head resize

* refactor

* break lines to make sylvain happy

* add news tests

* fix typo

* improve test

* skip bart-like for now

* check if base_model = get(...) is necessary

* clean files

* improve test

* fix tests

* revert style templates

* Update templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
This commit is contained in:
Patrick von Platen
2020-12-02 19:19:50 +01:00
committed by GitHub
parent e52f9c0ade
commit 443f67e887
30 changed files with 273 additions and 57 deletions

View File

@@ -444,6 +444,20 @@ class T5ModelTester:
)
)
def check_resize_embeddings_t5_v1_1(
self,
config,
):
prev_vocab_size = config.vocab_size
config.tie_word_embeddings = False
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
model.resize_token_embeddings(prev_vocab_size - 10)
self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10)
self.parent.assertEqual(model.get_output_embeddings().weight.shape[0], prev_vocab_size - 10)
self.parent.assertEqual(model.config.vocab_size, prev_vocab_size - 10)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -480,7 +494,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
)
test_pruning = False
test_torchscript = True
test_resize_embeddings = False
test_resize_embeddings = True
test_model_parallel = True
is_encoder_decoder = True
@@ -536,6 +550,10 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
def test_v1_1_resize_embeddings(self):
config = self.model_tester.prepare_config_and_inputs()[0]
self.model_tester.check_resize_embeddings_t5_v1_1(config)
@slow
def test_model_from_pretrained(self):
for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: