From 6d585fe0f0921b86c05723d40d76dce91ba51165 Mon Sep 17 00:00:00 2001
From: SaulLu <55560583+SaulLu@users.noreply.github.com>
Date: Tue, 1 Feb 2022 16:13:58 +0100
Subject: [PATCH] replace assert with exception for padding_side arg in `PreTrainedTokenizerBase` `__init__` (#15454)

* replace assert with exception for `padding_side` arg in `PreTrainedTokenizerBase` `__init__`

* add test

* fix kwargs

* reformat test

* format

* format

* fix typo to render the documentation
---
 src/transformers/tokenization_utils_base.py | 10 +++---
 tests/test_tokenization_common.py           | 37 +++++++++++++++++++++
 2 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 8389e7a6cf..ebd83e4214 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1389,7 +1389,7 @@ INIT_TOKENIZER_DOCSTRING = r"""
             loaded with [`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], this will be set to the
             value stored for the associated model in `max_model_input_sizes` (see above). If no value is provided, will
             default to VERY_LARGE_INTEGER (`int(1e30)`).
-        padding_side: (`str`, *optional*):
+        padding_side (`str`, *optional*):
             The side on which the model should have padding applied. Should be selected between ['right', 'left'].
             Default value is picked from the class attribute of the same name.
         model_input_names (`List[string]`, *optional*):
@@ -1456,10 +1456,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
 
         # Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed.
         self.padding_side = kwargs.pop("padding_side", self.padding_side)
-        assert self.padding_side in [
-            "right",
-            "left",
-        ], f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
+        if self.padding_side not in ["right", "left"]:
+            raise ValueError(
+                f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
+            )
         self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
 
         self.deprecation_warnings = (
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index bee7ee7209..0cfbeb7f53 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -1367,6 +1367,43 @@ class TokenizerTesterMixin:
                 filtered_sequence = [x for x in filtered_sequence if x is not None]
                 self.assertEqual(encoded_sequence, filtered_sequence)
 
+    def test_padding_side_in_kwargs(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                if self.test_rust_tokenizer:
+                    tokenizer_r = self.rust_tokenizer_class.from_pretrained(
+                        pretrained_name, padding_side="left", **kwargs
+                    )
+                    self.assertEqual(tokenizer_r.padding_side, "left")
+
+                    tokenizer_r = self.rust_tokenizer_class.from_pretrained(
+                        pretrained_name, padding_side="right", **kwargs
+                    )
+                    self.assertEqual(tokenizer_r.padding_side, "right")
+
+                    self.assertRaises(
+                        ValueError,
+                        self.rust_tokenizer_class.from_pretrained,
+                        pretrained_name,
+                        padding_side="unauthorized",
+                        **kwargs,
+                    )
+
+                if self.test_slow_tokenizer:
+                    tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, padding_side="left", **kwargs)
+                    self.assertEqual(tokenizer_p.padding_side, "left")
+
+                    tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, padding_side="right", **kwargs)
+                    self.assertEqual(tokenizer_p.padding_side, "right")
+
+                    self.assertRaises(
+                        ValueError,
+                        self.tokenizer_class.from_pretrained,
+                        pretrained_name,
+                        padding_side="unauthorized",
+                        **kwargs,
+                    )
+
     def test_right_and_left_padding(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers: