From 9f8bfe703c4e8b88f462b23ff4968b1385bd955b Mon Sep 17 00:00:00 2001 From: davidleonfdez <45669232+davidleonfdez@users.noreply.github.com> Date: Wed, 13 Apr 2022 12:49:06 +0100 Subject: [PATCH] Fix #16660 (tokenizers setters of ids of special tokens) (#16661) * Fix setters of *_token_id properties of SpecialTokensMixin * Test setters of common tokens ids * Move to a separate test checks of setters of tokens ids * Add independent test for ByT5 * Add Canine test * Test speech to text --- src/transformers/tokenization_utils_base.py | 16 ++++----- tests/byt5/test_tokenization_byt5.py | 38 +++++++++++++++++++++ tests/canine/test_tokenization_canine.py | 37 ++++++++++++++++++++ tests/test_tokenization_common.py | 37 ++++++++++++++++++++ 4 files changed, 120 insertions(+), 8 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index c76197f544..899c5d3a02 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1150,35 +1150,35 @@ class SpecialTokensMixin: @bos_token_id.setter def bos_token_id(self, value): - self._bos_token = self.convert_tokens_to_ids(value) + self._bos_token = self.convert_ids_to_tokens(value) if value is not None else None @eos_token_id.setter def eos_token_id(self, value): - self._eos_token = self.convert_tokens_to_ids(value) + self._eos_token = self.convert_ids_to_tokens(value) if value is not None else None @unk_token_id.setter def unk_token_id(self, value): - self._unk_token = self.convert_tokens_to_ids(value) + self._unk_token = self.convert_ids_to_tokens(value) if value is not None else None @sep_token_id.setter def sep_token_id(self, value): - self._sep_token = self.convert_tokens_to_ids(value) + self._sep_token = self.convert_ids_to_tokens(value) if value is not None else None @pad_token_id.setter def pad_token_id(self, value): - self._pad_token = self.convert_tokens_to_ids(value) + self._pad_token = self.convert_ids_to_tokens(value) if value is not None else None @cls_token_id.setter def cls_token_id(self, value): - self._cls_token = self.convert_tokens_to_ids(value) + self._cls_token = self.convert_ids_to_tokens(value) if value is not None else None @mask_token_id.setter def mask_token_id(self, value): - self._mask_token = self.convert_tokens_to_ids(value) + self._mask_token = self.convert_ids_to_tokens(value) if value is not None else None @additional_special_tokens_ids.setter def additional_special_tokens_ids(self, values): - self._additional_special_tokens = [self.convert_tokens_to_ids(value) for value in values] + self._additional_special_tokens = [self.convert_ids_to_tokens(value) for value in values] @property def special_tokens_map(self) -> Dict[str, Union[str, List[str]]]: diff --git a/tests/byt5/test_tokenization_byt5.py b/tests/byt5/test_tokenization_byt5.py index eb210530f0..7e4f97d374 100644 --- a/tests/byt5/test_tokenization_byt5.py +++ b/tests/byt5/test_tokenization_byt5.py @@ -332,3 +332,41 @@ class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase): string = tokenizer.convert_tokens_to_string(tokens) self.assertIsInstance(string, str) + + # We need a different implementation of the test of the same name defined in TokenizerTesterMixin because this tokenizer + # doesn't have a vocab + def test_tokenizers_common_ids_setters(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + attributes_list = [ + "bos_token", + "eos_token", + "unk_token", + "sep_token", + "pad_token", + "cls_token", + "mask_token", + ] + + token_id_to_test_setters = 0 + token_to_test_setters = tokenizer.convert_ids_to_tokens( + token_id_to_test_setters, skip_special_tokens=False + ) + + for attr in attributes_list: + setattr(tokenizer, attr + "_id", None) + self.assertEqual(getattr(tokenizer, attr), None) + self.assertEqual(getattr(tokenizer, attr + "_id"), None) + + setattr(tokenizer, attr + "_id", token_id_to_test_setters) + self.assertEqual(getattr(tokenizer, attr), token_to_test_setters) + self.assertEqual(getattr(tokenizer, attr + "_id"), token_id_to_test_setters) + + setattr(tokenizer, "additional_special_tokens_ids", []) + self.assertListEqual(getattr(tokenizer, "additional_special_tokens"), []) + self.assertListEqual(getattr(tokenizer, "additional_special_tokens_ids"), []) + + setattr(tokenizer, "additional_special_tokens_ids", [token_id_to_test_setters]) + self.assertListEqual(getattr(tokenizer, "additional_special_tokens"), [token_to_test_setters]) + self.assertListEqual(getattr(tokenizer, "additional_special_tokens_ids"), [token_id_to_test_setters]) diff --git a/tests/canine/test_tokenization_canine.py b/tests/canine/test_tokenization_canine.py index d894237ff5..0a949e6d78 100644 --- a/tests/canine/test_tokenization_canine.py +++ b/tests/canine/test_tokenization_canine.py @@ -271,6 +271,43 @@ class CanineTokenizationTest(TokenizerTesterMixin, unittest.TestCase): decoded = tokenizer.decode(encoded, spaces_between_special_tokens=self.space_between_special_tokens) self.assertIn(decoded, [output, output.lower()]) + # cannot use default `test_tokenizers_common_ids_setters` method because tokenizer has no vocab + def test_tokenizers_common_ids_setters(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + attributes_list = [ + "bos_token", + "eos_token", + "unk_token", + "sep_token", + "pad_token", + "cls_token", + "mask_token", + ] + + token_to_test_setters = "a" + token_id_to_test_setters = ord(token_to_test_setters) + + for attr in attributes_list: + setattr(tokenizer, attr + "_id", None) + self.assertEqual(getattr(tokenizer, attr), None) + self.assertEqual(getattr(tokenizer, attr + "_id"), None) + + setattr(tokenizer, attr + "_id", token_id_to_test_setters) + self.assertEqual(getattr(tokenizer, attr), token_to_test_setters) + self.assertEqual(getattr(tokenizer, attr + "_id"), token_id_to_test_setters) + + setattr(tokenizer, "additional_special_tokens_ids", []) + self.assertListEqual(getattr(tokenizer, "additional_special_tokens"), []) + self.assertListEqual(getattr(tokenizer, "additional_special_tokens_ids"), []) + + additional_special_token_id = 0xE006 + additional_special_token = chr(additional_special_token_id) + setattr(tokenizer, "additional_special_tokens_ids", [additional_special_token_id]) + self.assertListEqual(getattr(tokenizer, "additional_special_tokens"), [additional_special_token]) + self.assertListEqual(getattr(tokenizer, "additional_special_tokens_ids"), [additional_special_token_id]) + # tokenizer has a fixed vocab_size (namely all possible unicode code points) def test_add_tokens_tokenizer(self): pass diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 2d26d76b9a..fe16e5e1cd 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -540,6 +540,43 @@ class TokenizerTesterMixin: for attr in attributes_list: self.assertTrue(hasattr(tokenizer, attr)) + def test_tokenizers_common_ids_setters(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + attributes_list = [ + "bos_token", + "eos_token", + "unk_token", + "sep_token", + "pad_token", + "cls_token", + "mask_token", + ] + + vocab = tokenizer.get_vocab() + token_id_to_test_setters = next(iter(vocab.values())) + token_to_test_setters = tokenizer.convert_ids_to_tokens( + token_id_to_test_setters, skip_special_tokens=False + ) + + for attr in attributes_list: + setattr(tokenizer, attr + "_id", None) + self.assertEqual(getattr(tokenizer, attr), None) + self.assertEqual(getattr(tokenizer, attr + "_id"), None) + + setattr(tokenizer, attr + "_id", token_id_to_test_setters) + self.assertEqual(getattr(tokenizer, attr), token_to_test_setters) + self.assertEqual(getattr(tokenizer, attr + "_id"), token_id_to_test_setters) + + setattr(tokenizer, "additional_special_tokens_ids", []) + self.assertListEqual(getattr(tokenizer, "additional_special_tokens"), []) + self.assertListEqual(getattr(tokenizer, "additional_special_tokens_ids"), []) + + setattr(tokenizer, "additional_special_tokens_ids", [token_id_to_test_setters]) + self.assertListEqual(getattr(tokenizer, "additional_special_tokens"), [token_to_test_setters]) + self.assertListEqual(getattr(tokenizer, "additional_special_tokens_ids"), [token_id_to_test_setters]) + def test_save_and_load_tokenizer(self): # safety check on max_len default value so we are sure the test works tokenizers = self.get_tokenizers()