[PyTorch] Refactor Resize Token Embeddings (#8880)
* fix resize tokens
* correct mobile_bert
* move embedding fix into modeling_utils.py
* refactor
* fix lm head resize
* refactor
* break lines to make sylvain happy
* add news tests
* fix typo
* improve test
* skip bart-like for now
* check if base_model = get(...) is necessary
* clean files
* improve test
* fix tests
* revert style templates
* Update templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
This commit is contained in:
committed by
GitHub
parent
e52f9c0ade
commit
443f67e887
@@ -208,6 +208,10 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
|
||||
def test_resize_embeddings_untied(self):
|
||||
pass
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
def test_tiny_model(self):
|
||||
|
||||
@@ -128,6 +128,10 @@ class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase):
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
|
||||
def test_resize_embeddings_untied(self):
|
||||
pass
|
||||
|
||||
|
||||
@unittest.skipUnless(torch_device != "cpu", "3B test too slow on CPU.")
|
||||
@require_torch
|
||||
|
||||
@@ -815,6 +815,10 @@ class ModelTesterMixin:
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
# Input ids should be clamped to the maximum size of the vocabulary
|
||||
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||
|
||||
# make sure that decoder_input_ids are resized as well
|
||||
if "decoder_input_ids" in inputs_dict:
|
||||
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
|
||||
@@ -825,6 +829,57 @@ class ModelTesterMixin:
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def test_resize_embeddings_untied(self):
|
||||
(
|
||||
original_config,
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_resize_embeddings:
|
||||
return
|
||||
|
||||
original_config.tie_word_embeddings = False
|
||||
|
||||
# if model cannot untied embeddings -> leave test
|
||||
if original_config.tie_word_embeddings:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config = copy.deepcopy(original_config)
|
||||
model = model_class(config).to(torch_device)
|
||||
|
||||
# if no output embeddings -> leave test
|
||||
if model.get_output_embeddings() is None:
|
||||
continue
|
||||
|
||||
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||
model_vocab_size = config.vocab_size
|
||||
model.resize_token_embeddings(model_vocab_size + 10)
|
||||
self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
|
||||
output_embeds = model.get_output_embeddings()
|
||||
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
|
||||
# Check bias if present
|
||||
if output_embeds.bias is not None:
|
||||
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||
model.resize_token_embeddings(model_vocab_size - 15)
|
||||
self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
output_embeds = model.get_output_embeddings()
|
||||
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
|
||||
# Check bias if present
|
||||
if output_embeds.bias is not None:
|
||||
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
# Input ids should be clamped to the maximum size of the vocabulary
|
||||
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||
if "decoder_input_ids" in inputs_dict:
|
||||
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
@@ -226,15 +226,9 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
|
||||
# def test_auto_model(self):
|
||||
# # XXX: add a tiny model to s3?
|
||||
# model_name = "facebook/wmt19-ru-en-tiny"
|
||||
# tiny = AutoModel.from_pretrained(model_name) # same vocab size
|
||||
# tok = AutoTokenizer.from_pretrained(model_name) # same tokenizer
|
||||
# inputs_dict = tok.batch_encode_plus(["Hello my friends"], return_tensors="pt")
|
||||
|
||||
# with torch.no_grad():
|
||||
# tiny(**inputs_dict)
|
||||
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
|
||||
def test_resize_embeddings_untied(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
|
||||
@@ -574,6 +574,10 @@ class ReformerTesterMixin:
|
||||
# reformer cannot keep gradients in attentions or hidden states
|
||||
return
|
||||
|
||||
def test_resize_embeddings_untied(self):
|
||||
# reformer cannot resize embeddings that easily
|
||||
return
|
||||
|
||||
|
||||
@require_torch
|
||||
class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@@ -444,6 +444,20 @@ class T5ModelTester:
|
||||
)
|
||||
)
|
||||
|
||||
def check_resize_embeddings_t5_v1_1(
|
||||
self,
|
||||
config,
|
||||
):
|
||||
prev_vocab_size = config.vocab_size
|
||||
|
||||
config.tie_word_embeddings = False
|
||||
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
|
||||
model.resize_token_embeddings(prev_vocab_size - 10)
|
||||
|
||||
self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10)
|
||||
self.parent.assertEqual(model.get_output_embeddings().weight.shape[0], prev_vocab_size - 10)
|
||||
self.parent.assertEqual(model.config.vocab_size, prev_vocab_size - 10)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -480,7 +494,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
)
|
||||
test_pruning = False
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = False
|
||||
test_resize_embeddings = True
|
||||
test_model_parallel = True
|
||||
is_encoder_decoder = True
|
||||
|
||||
@@ -536,6 +550,10 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
|
||||
|
||||
def test_v1_1_resize_embeddings(self):
|
||||
config = self.model_tester.prepare_config_and_inputs()[0]
|
||||
self.model_tester.check_resize_embeddings_t5_v1_1(config)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
||||
@@ -299,6 +299,10 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
|
||||
self.assertEqual(model_vocab_size, model.config.vocab_size)
|
||||
self.assertEqual(model_embed.emb_layers[layer].weight.shape[0], cloned_embeddings[layer].shape[0])
|
||||
|
||||
def test_resize_embeddings_untied(self):
|
||||
# transfo-xl requires special resize for lm-head
|
||||
return
|
||||
|
||||
|
||||
@require_torch
|
||||
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user