From 6e87010060ddfb334293df4228123ed7a4ca6ad7 Mon Sep 17 00:00:00 2001 From: SaulLu <55560583+SaulLu@users.noreply.github.com> Date: Fri, 16 Jul 2021 18:26:54 +0200 Subject: [PATCH] Preserve `list` type of `additional_special_tokens` in `special_token_map` (#12759) * preserve type of `additional_special_tokens` in `special_token_map` * format * Update src/transformers/tokenization_utils_base.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/tokenization_utils_base.py | 6 +++++- tests/test_tokenization_common.py | 4 ++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 99fa84ad8e..59ff00f0f7 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1192,7 +1192,11 @@ class SpecialTokensMixin: for attr in self.SPECIAL_TOKENS_ATTRIBUTES: attr_value = getattr(self, "_" + attr) if attr_value: - set_attr[attr] = str(attr_value) + set_attr[attr] = ( + type(attr_value)(str(attr_value_sub) for attr_value_sub in attr_value) + if isinstance(attr_value, (list, tuple)) + else str(attr_value) + ) return set_attr @property diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 0a662cc62c..dbc6af764e 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -2462,6 +2462,10 @@ class TokenizerTesterMixin: self.assertEqual( tokenizer_r.add_special_tokens({"additional_special_tokens": ["", ""]}), 2 ) + self.assertIn("", tokenizer_r.special_tokens_map["additional_special_tokens"]) + self.assertIsInstance(tokenizer_r.special_tokens_map["additional_special_tokens"], list) + self.assertGreaterEqual(len(tokenizer_r.special_tokens_map["additional_special_tokens"]), 2) + self.assertEqual(len(tokenizer_r), vocab_size + 8) def test_offsets_mapping(self):