From 52e22cbf677f58a1cca347117e8974fa50c9f2d7 Mon Sep 17 00:00:00 2001 From: Duygu Altinok Date: Wed, 18 Sep 2024 13:32:02 +0300 Subject: [PATCH] Fix for slow the bug tokenizer adding spaces to single id decodes (#32564) * _decode signature change and quick return * added bunch of decoding tests * signature match and return * added tests for decoding * merged decoding test * more tests for special tokens * cosmetics * fixed param * ruffed the file * refinement for single special tokens * added test for single special tokens * slight change to test name Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> * minor change test name for skip tokens Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> * killed already defined var Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> * minor update with vars Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> * killed already defined var once more Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> --------- Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> --- src/transformers/tokenization_utils.py | 8 ++- tests/tokenization/test_tokenization_utils.py | 65 +++++++++++++++++++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 6a5bff3679..df13a029a6 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -1077,7 +1077,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): def _decode( self, - token_ids: List[int], + token_ids: Union[int, List[int]], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = None, spaces_between_special_tokens: bool = True, @@ -1086,6 +1086,10 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) + # If given is a single id, prevents splitting the string in upcoming loop + if isinstance(filtered_tokens, str): + filtered_tokens = [filtered_tokens] + legacy_added_tokens = set(self._added_tokens_encoder.keys()) - set(self.all_special_tokens) | { token for token in self.additional_special_tokens if self.convert_tokens_to_ids(token) >= self.vocab_size } @@ -1096,7 +1100,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase): current_sub_text = [] # TODO @ArthurZ in version 5, special tokens should be handled in convert_tokens_to_string, while _convert_tokens_to_string for token in filtered_tokens: - if skip_special_tokens and token in self.all_special_ids: + if skip_special_tokens and token in self.all_special_tokens: continue if token in legacy_added_tokens: if current_sub_text: diff --git a/tests/tokenization/test_tokenization_utils.py b/tests/tokenization/test_tokenization_utils.py index b43923df84..2c8f71ba97 100644 --- a/tests/tokenization/test_tokenization_utils.py +++ b/tests/tokenization/test_tokenization_utils.py @@ -253,6 +253,71 @@ class TokenizerUtilsTest(unittest.TestCase): self.assertTrue(isinstance(batch["input_ids"], np.ndarray)) self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]]) + @require_tokenizers + def test_decoding_single_token(self): + for tokenizer_class in [BertTokenizer, BertTokenizerFast]: + with self.subTest(f"{tokenizer_class}"): + tokenizer = tokenizer_class.from_pretrained("google-bert/bert-base-cased") + + token_id = 2300 + decoded_flat = tokenizer.decode(token_id) + decoded_list = tokenizer.decode([token_id]) + + self.assertEqual(decoded_flat, "Force") + self.assertEqual(decoded_list, "Force") + + token_id = 0 + decoded_flat = tokenizer.decode(token_id) + decoded_list = tokenizer.decode([token_id]) + + self.assertEqual(decoded_flat, "[PAD]") + self.assertEqual(decoded_list, "[PAD]") + + last_item_id = tokenizer.vocab_size - 1 + decoded_flat = tokenizer.decode(last_item_id) + decoded_list = tokenizer.decode([last_item_id]) + + self.assertEqual(decoded_flat, "##:") + self.assertEqual(decoded_list, "##:") + + @require_tokenizers + def test_decoding_skip_special_tokens(self): + for tokenizer_class in [BertTokenizer, BertTokenizerFast]: + with self.subTest(f"{tokenizer_class}"): + tokenizer = tokenizer_class.from_pretrained("google-bert/bert-base-cased") + tokenizer.add_tokens(["ஐ"], special_tokens=True) + + # test special token with other tokens, skip the special tokens + sentence = "This is a beautiful flower ஐ" + ids = tokenizer(sentence)["input_ids"] + decoded_sent = tokenizer.decode(ids, skip_special_tokens=True) + self.assertEqual(decoded_sent, "This is a beautiful flower") + + # test special token with other tokens, do not skip the special tokens + ids = tokenizer(sentence)["input_ids"] + decoded_sent = tokenizer.decode(ids, skip_special_tokens=False) + self.assertEqual(decoded_sent, "[CLS] This is a beautiful flower ஐ [SEP]") + + # test special token stand alone, skip the special tokens + sentence = "ஐ" + ids = tokenizer(sentence)["input_ids"] + decoded_sent = tokenizer.decode(ids, skip_special_tokens=True) + self.assertEqual(decoded_sent, "") + + # test special token stand alone, do not skip the special tokens + ids = tokenizer(sentence)["input_ids"] + decoded_sent = tokenizer.decode(ids, skip_special_tokens=False) + self.assertEqual(decoded_sent, "[CLS] ஐ [SEP]") + + # test single special token alone, skip + pad_id = 0 + decoded_sent = tokenizer.decode(pad_id, skip_special_tokens=True) + self.assertEqual(decoded_sent, "") + + # test single special token alone, do not skip + decoded_sent = tokenizer.decode(pad_id, skip_special_tokens=False) + self.assertEqual(decoded_sent, "[PAD]") + @require_torch def test_padding_accepts_tensors_pt(self): import torch