[Wav2vec2] Fixed tokenization mistakes while adding single-char tokens to tokenizer (#11538)
* Fixed tokenization mistakes while adding single-char tokens to tokenizer
* Added tests and removed unnecessary comments.
* finalize wav2vec2 tok
* add more aggressive tests
* Apply suggestions from code review
* fix useless import

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -24,8 +24,8 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
|
||||
from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...tokenization_utils import PreTrainedTokenizer, _insert_one_token_to_ordered_list
|
||||
from ...tokenization_utils_base import AddedToken, BatchEncoding
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@@ -277,6 +277,61 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||
|
||||
return (vocab_file,)
|
||||
|
||||
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
    """
    Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to
    it with indices starting from length of the current vocabulary.

    Args:
        new_tokens (:obj:`List[str]` or :obj:`List[tokenizers.AddedToken]`):
            Token(s) to add in vocabulary. A token is only added if it's not already in the vocabulary (tested by
            checking if the tokenizer assigns the index of the ``unk_token`` to them).
        special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the tokens should be added as special tokens.

    Returns:
        :obj:`int`: The number of tokens actually added to the vocabulary.

    Examples::

        # Let's see how to increase the vocabulary of a Wav2Vec2 model and tokenizer
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('facebook/wav2vec2-base-960h')
        model = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base-960h')

        num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
        print('We have added', num_added_toks, 'tokens')
        # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
        model.resize_token_embeddings(len(tokenizer))
    """
    # Normalise every entry (plain strings and AddedToken objects alike) to ``str``.
    new_tokens = [str(tok) for tok in new_tokens]
    if not new_tokens:
        return 0

    # None of these change while we scan the candidates (the added-token maps are
    # only updated after the loop), so compute them once instead of per token.
    lower_case = not special_tokens and getattr(self, "do_lower_case", False)
    unk_token = self.unk_token
    unk_id = self.convert_tokens_to_ids(unk_token)

    tokens_to_add = []
    for token in new_tokens:
        if lower_case:
            token = token.lower()
        # A token is "new" when it currently resolves to the unknown-token id,
        # is not the unknown token itself, and was not already queued in this call.
        if token != unk_token and self.convert_tokens_to_ids(token) == unk_id and token not in tokens_to_add:
            tokens_to_add.append(token)
            if self.verbose:
                logger.info(f"Adding {token} to the vocabulary")

    # New ids are assigned consecutively after everything the tokenizer already knows.
    first_new_id = len(self)
    added_tok_encoder = {tok: first_new_id + i for i, tok in enumerate(tokens_to_add)}
    added_tok_decoder = {idx: tok for tok, idx in added_tok_encoder.items()}
    self.added_tokens_encoder.update(added_tok_encoder)
    self.added_tokens_decoder.update(added_tok_decoder)

    # Make sure we don't split on any special tokens (even if they were already in the vocab before).
    # Only multi-character tokens are registered as no-split tokens; single-character
    # tokens are left out of the no-split list.
    for token in tokens_to_add:
        if len(token) > 1:
            self._additional_special_tokens.append(AddedToken(token))
            _insert_one_token_to_ordered_list(self.unique_no_split_tokens, token)

    return len(tokens_to_add)
|
||||
|
||||
|
||||
class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
|
||||
@@ -375,6 +375,38 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def test_tokenizer_add_token_chars(self):
    """Encode text after adding single-character tokens and compare against the expected ids."""
    tok = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")

    # Each step is (tokens to add, text to encode, expected input ids).
    # The first step adds a single token passed as a bare string; the last step
    # adds the same list again before encoding text without separating spaces.
    steps = [
        ("x", "C x A", [19, 4, 32, 4, 7]),
        (["a", "b", "c"], "C a A c", [19, 4, 33, 4, 7, 4, 35]),
        (["a", "b", "c"], "CaA c", [19, 33, 7, 4, 35]),
    ]
    for additions, text, expected_ids in steps:
        tok.add_tokens(additions)
        self.assertEqual(tok(text).input_ids, expected_ids)
|
||||
|
||||
def test_tokenizer_add_token_words(self):
    """Encode text after adding multi-character tokens and compare against the expected ids."""
    tok = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")

    # Each step is (tokens to add, text to encode, expected input ids).
    # The first step adds a single token passed as a bare string; the last step
    # adds the same list again before encoding text where one added token is
    # embedded between other characters.
    steps = [
        ("xxx", "C xxx A B", [19, 4, 32, 4, 7, 4, 24]),
        (["aaa", "bbb", "ccc"], "C aaa A ccc B B", [19, 4, 33, 4, 7, 4, 35, 4, 24, 4, 24]),
        (["aaa", "bbb", "ccc"], "CaaaA ccc B B", [19, 33, 7, 4, 35, 4, 24, 4, 24]),
    ]
    for additions, text, expected_ids in steps:
        tok.add_tokens(additions)
        self.assertEqual(tok(text).input_ids, expected_ids)
|
||||
|
||||
def test_tokenizer_decode(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
@@ -470,3 +502,55 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def test_pretrained_model_lists(self):
    """Intentionally a no-op: Wav2Vec2Model has no max model length, so there is nothing to test."""
|
||||
|
||||
# overwrite from test_tokenization_common
|
||||
def test_add_tokens_tokenizer(self):
    """Check size bookkeeping and encoded ids after adding regular tokens, then special tokens."""
    for tokenizer in self.get_tokenizers(do_lower_case=False):
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            base_vocab_size = tokenizer.vocab_size
            total_before = len(tokenizer)

            self.assertNotEqual(base_vocab_size, 0)

            # base_vocab_size is deliberately not compared with total_before: the test
            # vocab fixtures are smaller than the original vocabs, so tokenizers usually
            # carry added tokens from the start.

            # --- Step 1: add regular tokens ---
            regular_tokens = ["aaaaa bbbbbb", "cccccccccdddddddd"]
            n_regular_added = tokenizer.add_tokens(regular_tokens)
            vocab_size_after_regular = tokenizer.vocab_size
            total_after_regular = len(tokenizer)

            # vocab_size must stay put while len() grows by the number of added tokens.
            self.assertNotEqual(vocab_size_after_regular, 0)
            self.assertEqual(base_vocab_size, vocab_size_after_regular)
            self.assertEqual(n_regular_added, len(regular_tokens))
            self.assertEqual(total_after_regular, total_before + len(regular_tokens))

            ids = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False)

            # The first and third-from-last ids are expected to lie beyond the base vocabulary.
            self.assertGreaterEqual(len(ids), 4)
            self.assertGreater(ids[0], tokenizer.vocab_size - 1)
            self.assertGreater(ids[-3], tokenizer.vocab_size - 1)

            # --- Step 2: add special tokens ---
            special_tokens = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
            n_special_added = tokenizer.add_special_tokens(special_tokens)
            vocab_size_after_special = tokenizer.vocab_size
            total_after_special = len(tokenizer)

            self.assertNotEqual(vocab_size_after_special, 0)
            self.assertEqual(base_vocab_size, vocab_size_after_special)
            self.assertEqual(n_special_added, len(special_tokens))
            self.assertEqual(total_after_special, total_after_regular + len(special_tokens))

            ids = tokenizer.encode(
                ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
            )

            # The text starts with the new eos token and has the new pad token third from
            # the end; both must map to ids beyond the base vocabulary and larger than
            # the id of the neighbouring token checked below.
            self.assertGreaterEqual(len(ids), 6)
            self.assertGreater(ids[0], tokenizer.vocab_size - 1)
            self.assertGreater(ids[0], ids[1])
            self.assertGreater(ids[-3], tokenizer.vocab_size - 1)
            self.assertGreater(ids[-3], ids[-4])
            self.assertEqual(ids[0], tokenizer.eos_token_id)
            self.assertEqual(ids[-3], tokenizer.pad_token_id)
|
||||
|
||||
Reference in New Issue
Block a user