From b4727a1216bb21df2795e973063ed07202235d7e Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Thu, 1 Aug 2024 17:32:13 +0500 Subject: [PATCH] Fix conflicting key in init kwargs in PreTrainedTokenizerBase (#31233) * Fix conflicting key in init kwargs in PreTrainedTokenizerBase * Update code to check for callable key in save_pretrained * Apply PR suggestions * Invoke CI * Updates based on PR suggestion --- src/transformers/tokenization_utils_base.py | 4 ++++ tests/test_tokenization_common.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 60c52633a7..80f023f216 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1569,6 +1569,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): def __init__(self, **kwargs): # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``) self.init_inputs = () + for key in kwargs: + if hasattr(self, key) and callable(getattr(self, key)): + raise AttributeError(f"{key} conflicts with the method {key} in {self.__class__.__name__}") + self.init_kwargs = copy.deepcopy(kwargs) self.name_or_path = kwargs.pop("name_or_path", "") self._processor_class = kwargs.pop("processor_class", None) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 021557c1a5..d8ff702cbe 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -4408,3 +4408,11 @@ class TokenizerTesterMixin: replace_additional_special_tokens=False, ) self.assertEqual(tokenizer_2.additional_special_tokens, ["", "", ""]) + + def test_tokenizer_initialization_with_conflicting_key(self): + get_tokenizer_func = self.get_rust_tokenizer if self.test_rust_tokenizer else self.get_tokenizer + with self.assertRaises(AttributeError, msg="conflicts with the method"): + get_tokenizer_func(add_special_tokens=True) + + with self.assertRaises(AttributeError, msg="conflicts with the method"): + get_tokenizer_func(get_vocab=True)