Config: unified logic to retrieve text config (#33219)
This commit is contained in:
@@ -1747,12 +1747,13 @@ class ModelTesterMixin:
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def test_resize_tokens_embeddings(self):
|
||||
if not self.test_resize_embeddings:
|
||||
self.skipTest(reason="test_resize_embeddings is set to `False`")
|
||||
|
||||
(
|
||||
original_config,
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_resize_embeddings:
|
||||
self.skipTest(reason="test_resize_embeddings is set to `False`")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config = copy.deepcopy(original_config)
|
||||
@@ -1764,18 +1765,15 @@ class ModelTesterMixin:
|
||||
if self.model_tester.is_training is False:
|
||||
model.eval()
|
||||
|
||||
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
|
||||
model_vocab_size = config.get_text_config().vocab_size
|
||||
# Retrieve the embeddings and clone theme
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size)
|
||||
cloned_embeddings = model_embed.weight.clone()
|
||||
|
||||
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
new_model_vocab_size = model.config.get_text_config().vocab_size
|
||||
|
||||
self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
|
||||
@@ -1787,11 +1785,7 @@ class ModelTesterMixin:
|
||||
|
||||
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size - 15)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
new_model_vocab_size = model.config.get_text_config().vocab_size
|
||||
self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
|
||||
@@ -1817,21 +1811,13 @@ class ModelTesterMixin:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
|
||||
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
|
||||
model_vocab_size = config.get_text_config().vocab_size
|
||||
model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
new_model_vocab_size = model.config.get_text_config().vocab_size
|
||||
self.assertTrue(new_model_vocab_size + 10, model_vocab_size)
|
||||
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
new_model_vocab_size = model.config.get_text_config().vocab_size
|
||||
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
|
||||
|
||||
self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size)
|
||||
@@ -1852,13 +1838,10 @@ class ModelTesterMixin:
|
||||
model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)
|
||||
|
||||
def test_resize_embeddings_untied(self):
|
||||
(
|
||||
original_config,
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.test_resize_embeddings:
|
||||
self.skipTest(reason="test_resize_embeddings is set to `False`")
|
||||
|
||||
original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
original_config.tie_word_embeddings = False
|
||||
|
||||
# if model cannot untied embeddings -> leave test
|
||||
@@ -1874,13 +1857,9 @@ class ModelTesterMixin:
|
||||
continue
|
||||
|
||||
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
|
||||
model_vocab_size = config.get_text_config().vocab_size
|
||||
model.resize_token_embeddings(model_vocab_size + 10)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
new_model_vocab_size = model.config.get_text_config().vocab_size
|
||||
self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
|
||||
output_embeds = model.get_output_embeddings()
|
||||
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
|
||||
@@ -1892,11 +1871,7 @@ class ModelTesterMixin:
|
||||
|
||||
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||
model.resize_token_embeddings(model_vocab_size - 15)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
new_model_vocab_size = model.config.get_text_config().vocab_size
|
||||
self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
output_embeds = model.get_output_embeddings()
|
||||
@@ -1988,7 +1963,7 @@ class ModelTesterMixin:
|
||||
# self.assertTrue(check_same_values(embeddings, decoding))
|
||||
|
||||
# Check that after resize they remain tied.
|
||||
vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
|
||||
vocab_size = config.get_text_config().vocab_size
|
||||
model_tied.resize_token_embeddings(vocab_size + 10)
|
||||
params_tied_2 = list(model_tied.parameters())
|
||||
self.assertEqual(len(params_tied_2), len(params_tied))
|
||||
@@ -4831,7 +4806,7 @@ class ModelTesterMixin:
|
||||
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
batch_size, sequence_length = inputs["input_ids"].shape
|
||||
vocab_size = config.vocab_size
|
||||
vocab_size = config.get_text_config().vocab_size
|
||||
model = model_class(config).to(device=torch_device).eval()
|
||||
|
||||
# num_logits_to_keep=0 is a special case meaning "keep all logits"
|
||||
|
||||
Reference in New Issue
Block a user