[Wav2vec2] Fixed tokenization mistakes while adding single-char tokens to tokenizer (#11538)

* Fixed tokenization mistakes while adding single-char tokens to tokenizer

* Added tests and Removed unnecessary comments.

* finalize wav2vec2 tok

* add more aggressive tests

* Apply suggestions from code review

* fix useless import

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Muktan
2021-05-03 20:49:12 +05:30
committed by GitHub
parent f3cf8ae7b3
commit a721a5eefd
2 changed files with 141 additions and 2 deletions

View File

@@ -375,6 +375,38 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
kwargs.update(self.special_tokens_map)
return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def test_tokenizer_add_token_chars(self):
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
# check adding a single token
tokenizer.add_tokens("x")
token_ids = tokenizer("C x A").input_ids
self.assertEqual(token_ids, [19, 4, 32, 4, 7])
tokenizer.add_tokens(["a", "b", "c"])
token_ids = tokenizer("C a A c").input_ids
self.assertEqual(token_ids, [19, 4, 33, 4, 7, 4, 35])
tokenizer.add_tokens(["a", "b", "c"])
token_ids = tokenizer("CaA c").input_ids
self.assertEqual(token_ids, [19, 33, 7, 4, 35])
def test_tokenizer_add_token_words(self):
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
# check adding a single token
tokenizer.add_tokens("xxx")
token_ids = tokenizer("C xxx A B").input_ids
self.assertEqual(token_ids, [19, 4, 32, 4, 7, 4, 24])
tokenizer.add_tokens(["aaa", "bbb", "ccc"])
token_ids = tokenizer("C aaa A ccc B B").input_ids
self.assertEqual(token_ids, [19, 4, 33, 4, 7, 4, 35, 4, 24, 4, 24])
tokenizer.add_tokens(["aaa", "bbb", "ccc"])
token_ids = tokenizer("CaaaA ccc B B").input_ids
self.assertEqual(token_ids, [19, 33, 7, 4, 35, 4, 24, 4, 24])
def test_tokenizer_decode(self):
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
@@ -470,3 +502,55 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
def test_pretrained_model_lists(self):
# Wav2Vec2Model has no max model length => no testing
pass
# overwrite from test_tokenization_common
def test_add_tokens_tokenizer(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
vocab_size = tokenizer.vocab_size
all_size = len(tokenizer)
self.assertNotEqual(vocab_size, 0)
# We usually have added tokens from the start in tests because our vocab fixtures are
# smaller than the original vocabs - let's not assert this
# self.assertEqual(vocab_size, all_size)
new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
added_toks = tokenizer.add_tokens(new_toks)
vocab_size_2 = tokenizer.vocab_size
all_size_2 = len(tokenizer)
self.assertNotEqual(vocab_size_2, 0)
self.assertEqual(vocab_size, vocab_size_2)
self.assertEqual(added_toks, len(new_toks))
self.assertEqual(all_size_2, all_size + len(new_toks))
tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False)
self.assertGreaterEqual(len(tokens), 4)
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
self.assertGreater(tokens[-3], tokenizer.vocab_size - 1)
new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
vocab_size_3 = tokenizer.vocab_size
all_size_3 = len(tokenizer)
self.assertNotEqual(vocab_size_3, 0)
self.assertEqual(vocab_size, vocab_size_3)
self.assertEqual(added_toks_2, len(new_toks_2))
self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
tokens = tokenizer.encode(
">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
)
self.assertGreaterEqual(len(tokens), 6)
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
self.assertGreater(tokens[0], tokens[1])
self.assertGreater(tokens[-3], tokenizer.vocab_size - 1)
self.assertGreater(tokens[-3], tokens[-4])
self.assertEqual(tokens[0], tokenizer.eos_token_id)
self.assertEqual(tokens[-3], tokenizer.pad_token_id)