Config: unified logic to retrieve text config (#33219)

This commit is contained in:
Joao Gante
2024-09-04 12:03:30 +01:00
committed by GitHub
parent ebbe8d8014
commit d750b509fc
10 changed files with 91 additions and 88 deletions

View File

@@ -1747,12 +1747,13 @@ class ModelTesterMixin:
self.assertTrue(models_equal)
def test_resize_tokens_embeddings(self):
if not self.test_resize_embeddings:
self.skipTest(reason="test_resize_embeddings is set to `False`")
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:
self.skipTest(reason="test_resize_embeddings is set to `False`")
for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
@@ -1764,18 +1765,15 @@ class ModelTesterMixin:
if self.model_tester.is_training is False:
model.eval()
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
model_vocab_size = config.get_text_config().vocab_size
# Retrieve the embeddings and clone theme
model_embed = model.resize_token_embeddings(model_vocab_size)
cloned_embeddings = model_embed.weight.clone()
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
@@ -1787,11 +1785,7 @@ class ModelTesterMixin:
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size - 15)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
@@ -1817,21 +1811,13 @@ class ModelTesterMixin:
model = model_class(config)
model.to(torch_device)
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
model_vocab_size = config.get_text_config().vocab_size
model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertTrue(new_model_vocab_size + 10, model_vocab_size)
model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size)
@@ -1852,13 +1838,10 @@ class ModelTesterMixin:
model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)
def test_resize_embeddings_untied(self):
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:
self.skipTest(reason="test_resize_embeddings is set to `False`")
original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
original_config.tie_word_embeddings = False
# if model cannot untied embeddings -> leave test
@@ -1874,13 +1857,9 @@ class ModelTesterMixin:
continue
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
model_vocab_size = config.get_text_config().vocab_size
model.resize_token_embeddings(model_vocab_size + 10)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
output_embeds = model.get_output_embeddings()
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
@@ -1892,11 +1871,7 @@ class ModelTesterMixin:
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model.resize_token_embeddings(model_vocab_size - 15)
new_model_vocab_size = (
model.config.text_config.vocab_size
if hasattr(model.config, "text_config")
else model.config.vocab_size
)
new_model_vocab_size = model.config.get_text_config().vocab_size
self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
# Check that it actually resizes the embeddings matrix
output_embeds = model.get_output_embeddings()
@@ -1988,7 +1963,7 @@ class ModelTesterMixin:
# self.assertTrue(check_same_values(embeddings, decoding))
# Check that after resize they remain tied.
vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
vocab_size = config.get_text_config().vocab_size
model_tied.resize_token_embeddings(vocab_size + 10)
params_tied_2 = list(model_tied.parameters())
self.assertEqual(len(params_tied_2), len(params_tied))
@@ -4831,7 +4806,7 @@ class ModelTesterMixin:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
batch_size, sequence_length = inputs["input_ids"].shape
vocab_size = config.vocab_size
vocab_size = config.get_text_config().vocab_size
model = model_class(config).to(device=torch_device).eval()
# num_logits_to_keep=0 is a special case meaning "keep all logits"