From f9cc333805c47665c6afee8b5867931e54abe0c6 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 25 Jul 2023 18:45:01 +0200 Subject: [PATCH] [ `PreTrainedTokenizerFast`] Keep properties from fast tokenizer (#25053) * draft solution * use `setdefault` * nits * add tests and fix truncation issue * fix test * test passes locally * quality * updates * update tests --- src/transformers/tokenization_utils_fast.py | 20 ++++++++ tests/tokenization/test_tokenization_fast.py | 52 ++++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/src/transformers/tokenization_utils_fast.py b/src/transformers/tokenization_utils_fast.py index 471221e713..ced9d79593 100644 --- a/src/transformers/tokenization_utils_fast.py +++ b/src/transformers/tokenization_utils_fast.py @@ -132,6 +132,26 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase): self._decode_use_source_tokenizer = False + _truncation = self._tokenizer.truncation + + if _truncation is not None: + self._tokenizer.enable_truncation(**_truncation) + kwargs.setdefault("max_length", _truncation["max_length"]) + kwargs.setdefault("truncation_side", _truncation["direction"]) + kwargs.setdefault("stride", _truncation["stride"]) + kwargs.setdefault("truncation_strategy", _truncation["strategy"]) + else: + self._tokenizer.no_truncation() + + _padding = self._tokenizer.padding + if _padding is not None: + self._tokenizer.enable_padding(**_padding) + kwargs.setdefault("pad_token", _padding["pad_token"]) + kwargs.setdefault("pad_token_type_id", _padding["pad_type_id"]) + kwargs.setdefault("padding_side", _padding["direction"]) + kwargs.setdefault("max_length", _padding["length"]) + kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"]) + # We call this after having initialized the backend tokenizer because we update it. 
super().__init__(**kwargs) diff --git a/tests/tokenization/test_tokenization_fast.py b/tests/tokenization/test_tokenization_fast.py index da98d17d77..c6259610aa 100644 --- a/tests/tokenization/test_tokenization_fast.py +++ b/tests/tokenization/test_tokenization_fast.py @@ -109,6 +109,58 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase): encoding_ids = new_tokenizer.encode("a🤗") self.assertEqual(encoding_ids, [64, 172, 253, 97, 245]) + def test_init_from_tokenizers_model(self): + from tokenizers import Tokenizer + + sentences = ["Hello, y'all!", "How are you 😁 ? There should not be any issue right?"] + + tokenizer = Tokenizer.from_pretrained("t5-base") + # Enable padding + tokenizer.enable_padding(pad_id=0, pad_token="", length=512, pad_to_multiple_of=8) + self.assertEqual( + tokenizer.padding, + { + "length": 512, + "pad_to_multiple_of": 8, + "pad_id": 0, + "pad_token": "", + "pad_type_id": 0, + "direction": "right", + }, + ) + fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer) + tmpdirname = tempfile.mkdtemp() + fast_tokenizer.save_pretrained(tmpdirname) + fast_from_saved = PreTrainedTokenizerFast.from_pretrained(tmpdirname) + for tok in [fast_tokenizer, fast_from_saved]: + self.assertEqual(tok.pad_token_id, 0) + self.assertEqual(tok.padding_side, "right") + self.assertEqual(tok.pad_token, "") + self.assertEqual(tok.init_kwargs["max_length"], 512) + self.assertEqual(tok.init_kwargs["pad_to_multiple_of"], 8) + # fmt: off + self.assertEqual(tok(sentences, padding = True), {'input_ids': [[8774, 6, 3, 63, 31, 1748, 55, 1, 0, 0, 0, 0,0, 0, 0, 0],[ 571, 33, 25, 3, 2, 3, 58, 290, 225, 59, 36, 136, 962, 269, 58, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}) + # fmt: on + + tokenizer.enable_truncation(8, stride=0, 
strategy="longest_first", direction="right") + self.assertEqual( + tokenizer.truncation, {"max_length": 8, "stride": 0, "strategy": "longest_first", "direction": "right"} + ) + fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer) + tmpdirname = tempfile.mkdtemp() + fast_tokenizer.save_pretrained(tmpdirname) + fast_from_saved = PreTrainedTokenizerFast.from_pretrained(tmpdirname) + for tok in [fast_tokenizer, fast_from_saved]: + self.assertEqual(tok.truncation_side, "right") + self.assertEqual(tok.init_kwargs["truncation_strategy"], "longest_first") + self.assertEqual(tok.init_kwargs["max_length"], 8) + self.assertEqual(tok.init_kwargs["stride"], 0) + # NOTE even if the model has a default max_length, it is not used... + # thus tok(sentences, truncation = True) does nothing and does not warn either + # fmt: off + self.assertEqual(tok(sentences, truncation = True, max_length = 8), {'input_ids': [[8774, 6, 3, 63, 31, 1748, 55, 1],[ 571, 33, 25, 3, 2, 3, 58, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1]]}) + # fmt: on + @require_tokenizers class TokenizerVersioningTest(unittest.TestCase):