From 39b5d1a63a07d60e496a6bd98c3a60d32e8b9e6d Mon Sep 17 00:00:00 2001 From: SaulLu <55560583+SaulLu@users.noreply.github.com> Date: Wed, 2 Feb 2022 23:18:09 +0100 Subject: [PATCH] fix set truncation attribute in `__init__` of `PreTrainedTokenizerBase` (#15456) * change truncation_side in init of `PreTrainedTokenizerBase` Co-authored-by: LSinev * add test * Revert "replace assert with exception for `padding_side` arg in `PreTrainedTokenizerBase` `__init__`" This reverts commit 7a98b87962d2635c7e4d4f00db3948b694624843. * fix kwargs * Revert "fix kwargs" This reverts commit 67b0a5270e8cf1dbf70e6b0232e94c0452b6946f. * Update tests/test_tokenization_common.py Co-authored-by: Nicolas Patry * delete truncation_side variable * reorganize test * format * complete doc * Revert "Revert "replace assert with exception for `padding_side` arg in `PreTrainedTokenizerBase` `__init__`"" This reverts commit d5a10a7e2680539e5d9e98ae5d896c893d224b80. * fix typo * fix typos to render documentation * Revert "Revert "Revert "replace assert with exception for `padding_side` arg in `PreTrainedTokenizerBase` `__init__`""" This reverts commit 16cf58811943a08f43409a7c83eaa330686591d0. * format Co-authored-by: LSinev Co-authored-by: Nicolas Patry --- src/transformers/tokenization_utils_base.py | 15 +++++++- tests/test_tokenization_common.py | 41 +++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 802681550b..e60d4331ed 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1383,6 +1383,8 @@ INIT_TOKENIZER_DOCSTRING = r""" - **model_input_names** (`List[str]`) -- A list of inputs expected in the forward pass of the model. - **padding_side** (`str`) -- The default value for the side on which the model should have padding applied. Should be `'right'` or `'left'`. 
+ - **truncation_side** (`str`) -- The default value for the side on which the model should have truncation + applied. Should be `'right'` or `'left'`. Args: model_max_length (`int`, *optional*): @@ -1393,6 +1395,9 @@ INIT_TOKENIZER_DOCSTRING = r""" padding_side (`str`, *optional*): The side on which the model should have padding applied. Should be selected between ['right', 'left']. Default value is picked from the class attribute of the same name. + truncation_side (`str`, *optional*): + The side on which the model should have truncation applied. Should be selected between ['right', 'left']. + Default value is picked from the class attribute of the same name. model_input_names (`List[string]`, *optional*): The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or `"attention_mask"`). Default value is picked from the class attribute of the same name. @@ -1456,12 +1461,20 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None)) self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER - # Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed. + # Padding and truncation side are right by default and overridden in subclasses. If specified in the kwargs, it + # is changed. 
self.padding_side = kwargs.pop("padding_side", self.padding_side) if self.padding_side not in ["right", "left"]: raise ValueError( f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}" ) + + self.truncation_side = kwargs.pop("truncation_side", self.truncation_side) + if self.truncation_side not in ["right", "left"]: + raise ValueError( + f"Truncation side should be selected between 'right' and 'left', current value: {self.truncation_side}" + ) + self.model_input_names = kwargs.pop("model_input_names", self.model_input_names) self.deprecation_warnings = ( diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index cd7e9885f9..44c55b423c 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1415,6 +1415,47 @@ class TokenizerTesterMixin: **kwargs, ) + def test_truncation_side_in_kwargs(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + if self.test_rust_tokenizer: + tokenizer_r = self.rust_tokenizer_class.from_pretrained( + pretrained_name, truncation_side="left", **kwargs + ) + self.assertEqual(tokenizer_r.truncation_side, "left") + + tokenizer_r = self.rust_tokenizer_class.from_pretrained( + pretrained_name, truncation_side="right", **kwargs + ) + self.assertEqual(tokenizer_r.truncation_side, "right") + + self.assertRaises( + ValueError, + self.rust_tokenizer_class.from_pretrained, + pretrained_name, + truncation_side="unauthorized", + **kwargs, + ) + + if self.test_slow_tokenizer: + tokenizer_p = self.tokenizer_class.from_pretrained( + pretrained_name, truncation_side="left", **kwargs + ) + self.assertEqual(tokenizer_p.truncation_side, "left") + + tokenizer_p = self.tokenizer_class.from_pretrained( + pretrained_name, truncation_side="right", **kwargs + ) + self.assertEqual(tokenizer_p.truncation_side, "right") + + self.assertRaises( + ValueError, 
self.tokenizer_class.from_pretrained, + pretrained_name, + truncation_side="unauthorized", + **kwargs, + ) + def test_right_and_left_padding(self): tokenizers = self.get_tokenizers(do_lower_case=False) for tokenizer in tokenizers: