[Wav2vec2] Fixed tokenization mistakes while adding single-char tokens to tokenizer (#11538)
* Fixed tokenization mistakes while adding single-char tokens to tokenizer * Added tests and Removed unnecessary comments. * finalize wav2vec2 tok * add more aggressive tests * Apply suggestions from code review * fix useless import Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -375,6 +375,38 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def test_tokenizer_add_token_chars(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
# check adding a single token
|
||||
tokenizer.add_tokens("x")
|
||||
token_ids = tokenizer("C x A").input_ids
|
||||
self.assertEqual(token_ids, [19, 4, 32, 4, 7])
|
||||
|
||||
tokenizer.add_tokens(["a", "b", "c"])
|
||||
token_ids = tokenizer("C a A c").input_ids
|
||||
self.assertEqual(token_ids, [19, 4, 33, 4, 7, 4, 35])
|
||||
|
||||
tokenizer.add_tokens(["a", "b", "c"])
|
||||
token_ids = tokenizer("CaA c").input_ids
|
||||
self.assertEqual(token_ids, [19, 33, 7, 4, 35])
|
||||
|
||||
def test_tokenizer_add_token_words(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
# check adding a single token
|
||||
tokenizer.add_tokens("xxx")
|
||||
token_ids = tokenizer("C xxx A B").input_ids
|
||||
self.assertEqual(token_ids, [19, 4, 32, 4, 7, 4, 24])
|
||||
|
||||
tokenizer.add_tokens(["aaa", "bbb", "ccc"])
|
||||
token_ids = tokenizer("C aaa A ccc B B").input_ids
|
||||
self.assertEqual(token_ids, [19, 4, 33, 4, 7, 4, 35, 4, 24, 4, 24])
|
||||
|
||||
tokenizer.add_tokens(["aaa", "bbb", "ccc"])
|
||||
token_ids = tokenizer("CaaaA ccc B B").input_ids
|
||||
self.assertEqual(token_ids, [19, 33, 7, 4, 35, 4, 24, 4, 24])
|
||||
|
||||
def test_tokenizer_decode(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
@@ -470,3 +502,55 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def test_pretrained_model_lists(self):
|
||||
# Wav2Vec2Model has no max model length => no testing
|
||||
pass
|
||||
|
||||
# overwrite from test_tokenization_common
|
||||
def test_add_tokens_tokenizer(self):
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
for tokenizer in tokenizers:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__}"):
|
||||
vocab_size = tokenizer.vocab_size
|
||||
all_size = len(tokenizer)
|
||||
|
||||
self.assertNotEqual(vocab_size, 0)
|
||||
|
||||
# We usually have added tokens from the start in tests because our vocab fixtures are
|
||||
# smaller than the original vocabs - let's not assert this
|
||||
# self.assertEqual(vocab_size, all_size)
|
||||
|
||||
new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
|
||||
added_toks = tokenizer.add_tokens(new_toks)
|
||||
vocab_size_2 = tokenizer.vocab_size
|
||||
all_size_2 = len(tokenizer)
|
||||
|
||||
self.assertNotEqual(vocab_size_2, 0)
|
||||
self.assertEqual(vocab_size, vocab_size_2)
|
||||
self.assertEqual(added_toks, len(new_toks))
|
||||
self.assertEqual(all_size_2, all_size + len(new_toks))
|
||||
|
||||
tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False)
|
||||
|
||||
self.assertGreaterEqual(len(tokens), 4)
|
||||
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
||||
self.assertGreater(tokens[-3], tokenizer.vocab_size - 1)
|
||||
|
||||
new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
|
||||
added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
|
||||
vocab_size_3 = tokenizer.vocab_size
|
||||
all_size_3 = len(tokenizer)
|
||||
|
||||
self.assertNotEqual(vocab_size_3, 0)
|
||||
self.assertEqual(vocab_size, vocab_size_3)
|
||||
self.assertEqual(added_toks_2, len(new_toks_2))
|
||||
self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
|
||||
|
||||
tokens = tokenizer.encode(
|
||||
">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
|
||||
)
|
||||
|
||||
self.assertGreaterEqual(len(tokens), 6)
|
||||
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
|
||||
self.assertGreater(tokens[0], tokens[1])
|
||||
self.assertGreater(tokens[-3], tokenizer.vocab_size - 1)
|
||||
self.assertGreater(tokens[-3], tokens[-4])
|
||||
self.assertEqual(tokens[0], tokenizer.eos_token_id)
|
||||
self.assertEqual(tokens[-3], tokenizer.pad_token_id)
|
||||
|
||||
Reference in New Issue
Block a user