From 8e3980a290acc6d2f8ea76dba111b9ef0ef00309 Mon Sep 17 00:00:00 2001 From: Sam Passaglia <8333102+passaglia@users.noreply.github.com> Date: Wed, 20 Sep 2023 04:44:41 +0900 Subject: [PATCH] [FIX] resize_token_embeddings (#26102) * fix roundup command * add test for resize_token_embeddings * Update tests/test_modeling_common.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * style --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/modeling_utils.py | 2 +- tests/test_modeling_common.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 06170958ed..057caaedba 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1530,7 +1530,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) if new_num_tokens is None: new_num_tokens = old_embeddings.weight.shape[0] - new_num_tokens = ((new_num_tokens // pad_to_multiple_of) + 1) * pad_to_multiple_of + new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of else: logger.warning( "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. 
This means that the new embedding" diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 6e144ed476..764554b436 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1437,6 +1437,11 @@ class ModelTesterMixin: model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64) self.assertTrue(model_embed.weight.shape[0] // 64, 0) + # Check that resizing a model to a multiple of pad_to_multiple_of leads to a model of exactly that size + target_dimension = 128 + model_embed = model.resize_token_embeddings(target_dimension, pad_to_multiple_of=64) + self.assertTrue(model_embed.weight.shape[0], target_dimension) + with self.assertRaisesRegex( ValueError, "Asking to pad the embedding matrix to a multiple of `1.3`, which is not and integer. Please make sure to pass an integer",