[FIX] resize_token_embeddings (#26102)
* fix roundup command * add test for resize_token_embeddings * Update tests/test_modeling_common.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * style --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -1530,7 +1530,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
)
|
)
|
||||||
if new_num_tokens is None:
|
if new_num_tokens is None:
|
||||||
new_num_tokens = old_embeddings.weight.shape[0]
|
new_num_tokens = old_embeddings.weight.shape[0]
|
||||||
new_num_tokens = ((new_num_tokens // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding"
|
"You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding"
|
||||||
|
|||||||
@@ -1437,6 +1437,11 @@ class ModelTesterMixin:
|
|||||||
model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64)
|
model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64)
|
||||||
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
|
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
|
||||||
|
|
||||||
|
# Check that resizing a model to a multiple of pad_to_multiple leads to a model of exactly that size
|
||||||
|
target_dimension = 128
|
||||||
|
model_embed = model.resize_token_embeddings(target_dimension, pad_to_multiple_of=64)
|
||||||
|
self.assertTrue(model_embed.weight.shape[0], target_dimension)
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError,
|
ValueError,
|
||||||
"Asking to pad the embedding matrix to a multiple of `1.3`, which is not and integer. Please make sure to pass an integer",
|
"Asking to pad the embedding matrix to a multiple of `1.3`, which is not and integer. Please make sure to pass an integer",
|
||||||
|
|||||||
Reference in New Issue
Block a user