From 5e737018e1fcb22c8b76052058279552a8d6c806 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 28 May 2020 10:54:45 -0400 Subject: [PATCH] Fix add_special_tokens on fast tokenizers (#4531) --- src/transformers/tokenization_utils.py | 11 ++++++++--- tests/test_tokenization_fast.py | 3 ++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index b8ed4b5b8c..9e137b853d 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -2400,15 +2400,20 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): def add_special_tokens(self, special_tokens_dict: dict) -> int: # Map special tokens to class attributes (self.pad_token...) - num_added_tokens = super().add_special_tokens(special_tokens_dict) + super().add_special_tokens(special_tokens_dict) # If the backend tokenizer the only specificities of special tokens are that # - they will never be processed by the model, and # - they will be removed while decoding. # But they are not mapped to special attributes in the backend so we can just # send a list. - tokens = flatten(special_tokens_dict.values()) - self._tokenizer.add_special_tokens(tokens) + tokens = [] + for token in special_tokens_dict.values(): + if isinstance(token, list): + tokens += token + else: + tokens += [token] + num_added_tokens = self._tokenizer.add_special_tokens(tokens) return num_added_tokens diff --git a/tests/test_tokenization_fast.py b/tests/test_tokenization_fast.py index 5de1b589a6..5a2f3b04d1 100644 --- a/tests/test_tokenization_fast.py +++ b/tests/test_tokenization_fast.py @@ -221,6 +221,7 @@ class CommonFastTokenizerTest(unittest.TestCase): self.assertEqual(len(tokenizer_r), vocab_size + 3) self.assertEqual(tokenizer_r.add_special_tokens({}), 0) + self.assertEqual(tokenizer_r.add_special_tokens({"bos_token": "[BOS]", "eos_token": "[EOS]"}), 2) self.assertRaises( AssertionError, tokenizer_r.add_special_tokens, {"additional_special_tokens": ""} ) @@ -228,7 +229,7 @@ class CommonFastTokenizerTest(unittest.TestCase): self.assertEqual( tokenizer_r.add_special_tokens({"additional_special_tokens": ["", ""]}), 2 ) - self.assertEqual(len(tokenizer_r), vocab_size + 6) + self.assertEqual(len(tokenizer_r), vocab_size + 8) def assert_offsets_mapping(self, tokenizer_r): text = "Wonderful no inspiration example with subtoken"