Use text config's vocab size in testing models (#30568)
use text config's vocab size
This commit is contained in:
committed by
GitHub
parent
78fdd64dcf
commit
9d31b32e9d
@@ -1762,14 +1762,19 @@ class ModelTesterMixin:
|
||||
if self.model_tester.is_training is False:
|
||||
model.eval()
|
||||
|
||||
model_vocab_size = config.vocab_size
|
||||
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
|
||||
# Retrieve the embeddings and clone theme
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size)
|
||||
cloned_embeddings = model_embed.weight.clone()
|
||||
|
||||
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
|
||||
self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
@@ -1777,7 +1782,12 @@ class ModelTesterMixin:
|
||||
|
||||
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size - 15)
|
||||
self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
|
||||
|
||||
@@ -1802,15 +1812,25 @@ class ModelTesterMixin:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
|
||||
model_vocab_size = config.vocab_size
|
||||
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
|
||||
model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
|
||||
self.assertTrue(model.config.vocab_size + 10, model_vocab_size)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
self.assertTrue(new_model_vocab_size + 10, model_vocab_size)
|
||||
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
|
||||
|
||||
self.assertTrue(model_embed.weight.shape[0], model.config.vocab_size)
|
||||
self.assertTrue(model.config.vocab_size, model.vocab_size)
|
||||
self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size)
|
||||
self.assertTrue(new_model_vocab_size, model.vocab_size)
|
||||
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64)
|
||||
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
|
||||
@@ -1849,9 +1869,14 @@ class ModelTesterMixin:
|
||||
continue
|
||||
|
||||
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||
model_vocab_size = config.vocab_size
|
||||
model_vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
|
||||
model.resize_token_embeddings(model_vocab_size + 10)
|
||||
self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
|
||||
output_embeds = model.get_output_embeddings()
|
||||
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
|
||||
# Check bias if present
|
||||
@@ -1862,7 +1887,12 @@ class ModelTesterMixin:
|
||||
|
||||
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||
model.resize_token_embeddings(model_vocab_size - 15)
|
||||
self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
|
||||
new_model_vocab_size = (
|
||||
model.config.text_config.vocab_size
|
||||
if hasattr(model.config, "text_config")
|
||||
else model.config.vocab_size
|
||||
)
|
||||
self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
output_embeds = model.get_output_embeddings()
|
||||
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
|
||||
@@ -1949,7 +1979,8 @@ class ModelTesterMixin:
|
||||
# self.assertTrue(check_same_values(embeddings, decoding))
|
||||
|
||||
# Check that after resize they remain tied.
|
||||
model_tied.resize_token_embeddings(config.vocab_size + 10)
|
||||
vocab_size = config.text_config.vocab_size if hasattr(config, "text_config") else config.vocab_size
|
||||
model_tied.resize_token_embeddings(vocab_size + 10)
|
||||
params_tied_2 = list(model_tied.parameters())
|
||||
self.assertEqual(len(params_tied_2), len(params_tied))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user