Fix for slow the bug tokenizer adding spaces to single id decodes (#32564)
* _decode signature change and quick return * added bunch of decoding tests * signature match and return * added tests for decoding * merged decoding test * more tests for special tokens * cosmetics * fixed param * ruffed the file * refinement for single special tokens * added test for single special tokens * slight change to test name Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> * minor change test name for skip tokens Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> * killed already defined var Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> * minor update with vars Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> * killed already defined var once more Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com> --------- Co-authored-by: Ita Zaporozhets <31893021+itazap@users.noreply.github.com>
This commit is contained in:
@@ -253,6 +253,71 @@ class TokenizerUtilsTest(unittest.TestCase):
|
||||
self.assertTrue(isinstance(batch["input_ids"], np.ndarray))
|
||||
self.assertEqual(batch["input_ids"].tolist(), [[0, 1, 2, tokenizer.pad_token_id], [0, 1, 2, 3]])
|
||||
|
||||
@require_tokenizers
|
||||
def test_decoding_single_token(self):
|
||||
for tokenizer_class in [BertTokenizer, BertTokenizerFast]:
|
||||
with self.subTest(f"{tokenizer_class}"):
|
||||
tokenizer = tokenizer_class.from_pretrained("google-bert/bert-base-cased")
|
||||
|
||||
token_id = 2300
|
||||
decoded_flat = tokenizer.decode(token_id)
|
||||
decoded_list = tokenizer.decode([token_id])
|
||||
|
||||
self.assertEqual(decoded_flat, "Force")
|
||||
self.assertEqual(decoded_list, "Force")
|
||||
|
||||
token_id = 0
|
||||
decoded_flat = tokenizer.decode(token_id)
|
||||
decoded_list = tokenizer.decode([token_id])
|
||||
|
||||
self.assertEqual(decoded_flat, "[PAD]")
|
||||
self.assertEqual(decoded_list, "[PAD]")
|
||||
|
||||
last_item_id = tokenizer.vocab_size - 1
|
||||
decoded_flat = tokenizer.decode(last_item_id)
|
||||
decoded_list = tokenizer.decode([last_item_id])
|
||||
|
||||
self.assertEqual(decoded_flat, "##:")
|
||||
self.assertEqual(decoded_list, "##:")
|
||||
|
||||
@require_tokenizers
|
||||
def test_decoding_skip_special_tokens(self):
|
||||
for tokenizer_class in [BertTokenizer, BertTokenizerFast]:
|
||||
with self.subTest(f"{tokenizer_class}"):
|
||||
tokenizer = tokenizer_class.from_pretrained("google-bert/bert-base-cased")
|
||||
tokenizer.add_tokens(["ஐ"], special_tokens=True)
|
||||
|
||||
# test special token with other tokens, skip the special tokens
|
||||
sentence = "This is a beautiful flower ஐ"
|
||||
ids = tokenizer(sentence)["input_ids"]
|
||||
decoded_sent = tokenizer.decode(ids, skip_special_tokens=True)
|
||||
self.assertEqual(decoded_sent, "This is a beautiful flower")
|
||||
|
||||
# test special token with other tokens, do not skip the special tokens
|
||||
ids = tokenizer(sentence)["input_ids"]
|
||||
decoded_sent = tokenizer.decode(ids, skip_special_tokens=False)
|
||||
self.assertEqual(decoded_sent, "[CLS] This is a beautiful flower ஐ [SEP]")
|
||||
|
||||
# test special token stand alone, skip the special tokens
|
||||
sentence = "ஐ"
|
||||
ids = tokenizer(sentence)["input_ids"]
|
||||
decoded_sent = tokenizer.decode(ids, skip_special_tokens=True)
|
||||
self.assertEqual(decoded_sent, "")
|
||||
|
||||
# test special token stand alone, do not skip the special tokens
|
||||
ids = tokenizer(sentence)["input_ids"]
|
||||
decoded_sent = tokenizer.decode(ids, skip_special_tokens=False)
|
||||
self.assertEqual(decoded_sent, "[CLS] ஐ [SEP]")
|
||||
|
||||
# test single special token alone, skip
|
||||
pad_id = 0
|
||||
decoded_sent = tokenizer.decode(pad_id, skip_special_tokens=True)
|
||||
self.assertEqual(decoded_sent, "")
|
||||
|
||||
# test single special token alone, do not skip
|
||||
decoded_sent = tokenizer.decode(pad_id, skip_special_tokens=False)
|
||||
self.assertEqual(decoded_sent, "[PAD]")
|
||||
|
||||
@require_torch
|
||||
def test_padding_accepts_tensors_pt(self):
|
||||
import torch
|
||||
|
||||
Reference in New Issue
Block a user