Fix for slow the bug tokenizer adding spaces to single id decodes (#32564)
* _decode signature change and quick return * added bunch of decoding tests * signature match and return * added tests for decoding * merged decoding test * more tests for special tokens * cosmetics * fixed param * ruffed the file * refinement for single special tokens * added test for single special tokens * slight change to test name Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> * minor change test name for skip tokens Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> * killed already defined var Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> * minor update with vars Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> * killed already defined var once more Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> --------- Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com>
This commit is contained in:
@@ -1077,7 +1077,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
|||||||
|
|
||||||
def _decode(
|
def _decode(
|
||||||
self,
|
self,
|
||||||
token_ids: List[int],
|
token_ids: Union[int, List[int]],
|
||||||
skip_special_tokens: bool = False,
|
skip_special_tokens: bool = False,
|
||||||
clean_up_tokenization_spaces: bool = None,
|
clean_up_tokenization_spaces: bool = None,
|
||||||
spaces_between_special_tokens: bool = True,
|
spaces_between_special_tokens: bool = True,
|
||||||
@@ -1086,6 +1086,10 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
|||||||
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
|
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
|
||||||
|
|
||||||
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
|
# If given is a single id, prevents splitting the string in upcoming loop
|
||||||
|
if isinstance(filtered_tokens, str):
|
||||||
|
filtered_tokens = [filtered_tokens]
|
||||||
|
|
||||||
legacy_added_tokens = set(self._added_tokens_encoder.keys()) - set(self.all_special_tokens) | {
|
legacy_added_tokens = set(self._added_tokens_encoder.keys()) - set(self.all_special_tokens) | {
|
||||||
token for token in self.additional_special_tokens if self.convert_tokens_to_ids(token) >= self.vocab_size
|
token for token in self.additional_special_tokens if self.convert_tokens_to_ids(token) >= self.vocab_size
|
||||||
}
|
}
|
||||||
@@ -1096,7 +1100,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
|||||||
current_sub_text = []
|
current_sub_text = []
|
||||||
# TODO @ArthurZ in version 5, special tokens should be handled in convert_tokens_to_string, while _convert_tokens_to_string
|
# TODO @ArthurZ in version 5, special tokens should be handled in convert_tokens_to_string, while _convert_tokens_to_string
|
||||||
for token in filtered_tokens:
|
for token in filtered_tokens:
|
||||||
if skip_special_tokens and token in self.all_special_ids:
|
if skip_special_tokens and token in self.all_special_tokens:
|
||||||
continue
|
continue
|
||||||
if token in legacy_added_tokens:
|
if token in legacy_added_tokens:
|
||||||
if current_sub_text:
|
if current_sub_text:
|
||||||
|
|||||||
@@ -253,6 +253,71 @@ class TokenizerUtilsTest(unittest.TestCase):
|
|||||||
self.assertTrue(isinstance(batch["input_ids"], np.ndarray))
|
self.assertTrue(isinstance(batch["input_ids"], np.ndarray))
|
||||||
self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
|
self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
|
||||||
|
|
||||||
|
@require_tokenizers
|
||||||
|
def test_decoding_single_token(self):
|
||||||
|
for tokenizer_class in [BertTokenizer, BertTokenizerFast]:
|
||||||
|
with self.subTest(f"{tokenizer_class}"):
|
||||||
|
tokenizer = tokenizer_class.from_pretrained("google-bert/bert-base-cased")
|
||||||
|
|
||||||
|
token_id = 2300
|
||||||
|
decoded_flat = tokenizer.decode(token_id)
|
||||||
|
decoded_list = tokenizer.decode([token_id])
|
||||||
|
|
||||||
|
self.assertEqual(decoded_flat, "Force")
|
||||||
|
self.assertEqual(decoded_list, "Force")
|
||||||
|
|
||||||
|
token_id = 0
|
||||||
|
decoded_flat = tokenizer.decode(token_id)
|
||||||
|
decoded_list = tokenizer.decode([token_id])
|
||||||
|
|
||||||
|
self.assertEqual(decoded_flat, "[PAD]")
|
||||||
|
self.assertEqual(decoded_list, "[PAD]")
|
||||||
|
|
||||||
|
last_item_id = tokenizer.vocab_size - 1
|
||||||
|
decoded_flat = tokenizer.decode(last_item_id)
|
||||||
|
decoded_list = tokenizer.decode([last_item_id])
|
||||||
|
|
||||||
|
self.assertEqual(decoded_flat, "##:")
|
||||||
|
self.assertEqual(decoded_list, "##:")
|
||||||
|
|
||||||
|
@require_tokenizers
|
||||||
|
def test_decoding_skip_special_tokens(self):
|
||||||
|
for tokenizer_class in [BertTokenizer, BertTokenizerFast]:
|
||||||
|
with self.subTest(f"{tokenizer_class}"):
|
||||||
|
tokenizer = tokenizer_class.from_pretrained("google-bert/bert-base-cased")
|
||||||
|
tokenizer.add_tokens(["ஐ"], special_tokens=True)
|
||||||
|
|
||||||
|
# test special token with other tokens, skip the special tokens
|
||||||
|
sentence = "This is a beautiful flower ஐ"
|
||||||
|
ids = tokenizer(sentence)["input_ids"]
|
||||||
|
decoded_sent = tokenizer.decode(ids, skip_special_tokens=True)
|
||||||
|
self.assertEqual(decoded_sent, "This is a beautiful flower")
|
||||||
|
|
||||||
|
# test special token with other tokens, do not skip the special tokens
|
||||||
|
ids = tokenizer(sentence)["input_ids"]
|
||||||
|
decoded_sent = tokenizer.decode(ids, skip_special_tokens=False)
|
||||||
|
self.assertEqual(decoded_sent, "[CLS] This is a beautiful flower ஐ [SEP]")
|
||||||
|
|
||||||
|
# test special token stand alone, skip the special tokens
|
||||||
|
sentence = "ஐ"
|
||||||
|
ids = tokenizer(sentence)["input_ids"]
|
||||||
|
decoded_sent = tokenizer.decode(ids, skip_special_tokens=True)
|
||||||
|
self.assertEqual(decoded_sent, "")
|
||||||
|
|
||||||
|
# test special token stand alone, do not skip the special tokens
|
||||||
|
ids = tokenizer(sentence)["input_ids"]
|
||||||
|
decoded_sent = tokenizer.decode(ids, skip_special_tokens=False)
|
||||||
|
self.assertEqual(decoded_sent, "[CLS] ஐ [SEP]")
|
||||||
|
|
||||||
|
# test single special token alone, skip
|
||||||
|
pad_id = 0
|
||||||
|
decoded_sent = tokenizer.decode(pad_id, skip_special_tokens=True)
|
||||||
|
self.assertEqual(decoded_sent, "")
|
||||||
|
|
||||||
|
# test single special token alone, do not skip
|
||||||
|
decoded_sent = tokenizer.decode(pad_id, skip_special_tokens=False)
|
||||||
|
self.assertEqual(decoded_sent, "[PAD]")
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_padding_accepts_tensors_pt(self):
|
def test_padding_accepts_tensors_pt(self):
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
Reference in New Issue
Block a user