Fix for slow the bug tokenizer adding spaces to single id decodes (#32564)

* _decode signature change and quick return

* added bunch of decoding tests

* signature match and return

* added tests for decoding

* merged decoding test

* more tests for special tokens

* cosmetics

* fixed param

* ruffed the file

* refinement for single special tokens

* added test for single special tokens

* slight change to test name

Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com>

* minor change test name for skip tokens

Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com>

* killed already defined var

Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com>

* minor update with vars

Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com>

* killed already defined var once more

Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com>

---------

Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com>
This commit is contained in:
Duygu Altinok
2024-09-18 13:32:02 +03:00
committed by GitHub
parent e6d9f39dd7
commit 52e22cbf67
2 changed files with 71 additions and 2 deletions

View File

@@ -1077,7 +1077,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
def _decode( def _decode(
self, self,
token_ids: List[int], token_ids: Union[int, List[int]],
skip_special_tokens: bool = False, skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = None, clean_up_tokenization_spaces: bool = None,
spaces_between_special_tokens: bool = True, spaces_between_special_tokens: bool = True,
@@ -1086,6 +1086,10 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
# If given is a single id, prevents splitting the string in upcoming loop
if isinstance(filtered_tokens, str):
filtered_tokens = [filtered_tokens]
legacy_added_tokens = set(self._added_tokens_encoder.keys()) - set(self.all_special_tokens) | { legacy_added_tokens = set(self._added_tokens_encoder.keys()) - set(self.all_special_tokens) | {
token for token in self.additional_special_tokens if self.convert_tokens_to_ids(token) >= self.vocab_size token for token in self.additional_special_tokens if self.convert_tokens_to_ids(token) >= self.vocab_size
} }
@@ -1096,7 +1100,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
current_sub_text = [] current_sub_text = []
# TODO @ArthurZ in version 5, special tokens should be handled in convert_tokens_to_string, while _convert_tokens_to_string # TODO @ArthurZ in version 5, special tokens should be handled in convert_tokens_to_string, while _convert_tokens_to_string
for token in filtered_tokens: for token in filtered_tokens:
if skip_special_tokens and token in self.all_special_ids: if skip_special_tokens and token in self.all_special_tokens:
continue continue
if token in legacy_added_tokens: if token in legacy_added_tokens:
if current_sub_text: if current_sub_text:

View File

@@ -253,6 +253,71 @@ class TokenizerUtilsTest(unittest.TestCase):
self.assertTrue(isinstance(batch["input_ids"], np.ndarray)) self.assertTrue(isinstance(batch["input_ids"], np.ndarray))
self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]]) self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
@require_tokenizers
def test_decoding_single_token(self):
for tokenizer_class in [BertTokenizer, BertTokenizerFast]:
with self.subTest(f"{tokenizer_class}"):
tokenizer = tokenizer_class.from_pretrained("google-bert/bert-base-cased")
token_id = 2300
decoded_flat = tokenizer.decode(token_id)
decoded_list = tokenizer.decode([token_id])
self.assertEqual(decoded_flat, "Force")
self.assertEqual(decoded_list, "Force")
token_id = 0
decoded_flat = tokenizer.decode(token_id)
decoded_list = tokenizer.decode([token_id])
self.assertEqual(decoded_flat, "[PAD]")
self.assertEqual(decoded_list, "[PAD]")
last_item_id = tokenizer.vocab_size - 1
decoded_flat = tokenizer.decode(last_item_id)
decoded_list = tokenizer.decode([last_item_id])
self.assertEqual(decoded_flat, "##")
self.assertEqual(decoded_list, "##")
@require_tokenizers
def test_decoding_skip_special_tokens(self):
for tokenizer_class in [BertTokenizer, BertTokenizerFast]:
with self.subTest(f"{tokenizer_class}"):
tokenizer = tokenizer_class.from_pretrained("google-bert/bert-base-cased")
tokenizer.add_tokens([""], special_tokens=True)
# test special token with other tokens, skip the special tokens
sentence = "This is a beautiful flower ஐ"
ids = tokenizer(sentence)["input_ids"]
decoded_sent = tokenizer.decode(ids, skip_special_tokens=True)
self.assertEqual(decoded_sent, "This is a beautiful flower")
# test special token with other tokens, do not skip the special tokens
ids = tokenizer(sentence)["input_ids"]
decoded_sent = tokenizer.decode(ids, skip_special_tokens=False)
self.assertEqual(decoded_sent, "[CLS] This is a beautiful flower ஐ [SEP]")
# test special token stand alone, skip the special tokens
sentence = ""
ids = tokenizer(sentence)["input_ids"]
decoded_sent = tokenizer.decode(ids, skip_special_tokens=True)
self.assertEqual(decoded_sent, "")
# test special token stand alone, do not skip the special tokens
ids = tokenizer(sentence)["input_ids"]
decoded_sent = tokenizer.decode(ids, skip_special_tokens=False)
self.assertEqual(decoded_sent, "[CLS] ஐ [SEP]")
# test single special token alone, skip
pad_id = 0
decoded_sent = tokenizer.decode(pad_id, skip_special_tokens=True)
self.assertEqual(decoded_sent, "")
# test single special token alone, do not skip
decoded_sent = tokenizer.decode(pad_id, skip_special_tokens=False)
self.assertEqual(decoded_sent, "[PAD]")
@require_torch @require_torch
def test_padding_accepts_tensors_pt(self): def test_padding_accepts_tensors_pt(self):
import torch import torch