[Wav2vec2] Fixed tokenization mistakes while adding single-char tokens to tokenizer (#11538)

* Fixed tokenization mistakes while adding single-char tokens to tokenizer

* Added tests and Removed unnecessary comments.

* finalize wav2vec2 tok

* add more aggressive tests

* Apply suggestions from code review

* fix useless import

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Muktan
2021-05-03 20:49:12 +05:30
committed by GitHub
parent f3cf8ae7b3
commit a721a5eefd
2 changed files with 141 additions and 2 deletions

View File

@@ -24,8 +24,8 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings from ...file_utils import PaddingStrategy, TensorType, add_end_docstrings
from ...tokenization_utils import PreTrainedTokenizer from ...tokenization_utils import PreTrainedTokenizer, _insert_one_token_to_ordered_list
from ...tokenization_utils_base import BatchEncoding from ...tokenization_utils_base import AddedToken, BatchEncoding
from ...utils import logging from ...utils import logging
@@ -277,6 +277,61 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
return (vocab_file,) return (vocab_file,)
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to
it with indices starting from length of the current vocabulary.
Args:
new_tokens (:obj:`List[str]`or :obj:`List[tokenizers.AddedToken]`):
Token(s) to add in vocabulary. A token is only added if it's not already in the vocabulary (tested by
checking if the tokenizer assign the index of the ``unk_token`` to them).
special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the tokens should be added as special tokens.
Returns:
:obj:`int`: The number of tokens actually added to the vocabulary.
Examples::
# Let's see how to increase the vocabulary of Bert model and tokenizer
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('facebook/wav2vec2-base-960h')
model = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base-960h')
num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
print('We have added', num_added_toks, 'tokens')
# Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
model.resize_token_embeddings(len(tokenizer))
"""
new_tokens = [str(tok) for tok in new_tokens]
tokens_to_add = []
for token in new_tokens:
assert isinstance(token, str)
if not special_tokens and hasattr(self, "do_lower_case") and self.do_lower_case:
token = token.lower()
if (
token != self.unk_token
and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
and token not in tokens_to_add
):
tokens_to_add.append(token)
if self.verbose:
logger.info(f"Adding {token} to the vocabulary")
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder)
self.added_tokens_decoder.update(added_tok_decoder)
# Make sure we don't split on any special tokens (even they were already in the vocab before)
for token in tokens_to_add:
if len(token) > 1:
self._additional_special_tokens.append(AddedToken(token))
_insert_one_token_to_ordered_list(self.unique_no_split_tokens, token)
return len(tokens_to_add)
class Wav2Vec2Tokenizer(PreTrainedTokenizer): class Wav2Vec2Tokenizer(PreTrainedTokenizer):
""" """

View File

@@ -375,6 +375,38 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
kwargs.update(self.special_tokens_map) kwargs.update(self.special_tokens_map)
return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs) return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def test_tokenizer_add_token_chars(self):
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
# check adding a single token
tokenizer.add_tokens("x")
token_ids = tokenizer("C x A").input_ids
self.assertEqual(token_ids, [19, 4, 32, 4, 7])
tokenizer.add_tokens(["a", "b", "c"])
token_ids = tokenizer("C a A c").input_ids
self.assertEqual(token_ids, [19, 4, 33, 4, 7, 4, 35])
tokenizer.add_tokens(["a", "b", "c"])
token_ids = tokenizer("CaA c").input_ids
self.assertEqual(token_ids, [19, 33, 7, 4, 35])
def test_tokenizer_add_token_words(self):
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
# check adding a single token
tokenizer.add_tokens("xxx")
token_ids = tokenizer("C xxx A B").input_ids
self.assertEqual(token_ids, [19, 4, 32, 4, 7, 4, 24])
tokenizer.add_tokens(["aaa", "bbb", "ccc"])
token_ids = tokenizer("C aaa A ccc B B").input_ids
self.assertEqual(token_ids, [19, 4, 33, 4, 7, 4, 35, 4, 24, 4, 24])
tokenizer.add_tokens(["aaa", "bbb", "ccc"])
token_ids = tokenizer("CaaaA ccc B B").input_ids
self.assertEqual(token_ids, [19, 33, 7, 4, 35, 4, 24, 4, 24])
def test_tokenizer_decode(self): def test_tokenizer_decode(self):
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h") tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
@@ -470,3 +502,55 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
def test_pretrained_model_lists(self): def test_pretrained_model_lists(self):
# Wav2Vec2Model has no max model length => no testing # Wav2Vec2Model has no max model length => no testing
pass pass
# overwrite from test_tokenization_common
def test_add_tokens_tokenizer(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
vocab_size = tokenizer.vocab_size
all_size = len(tokenizer)
self.assertNotEqual(vocab_size, 0)
# We usually have added tokens from the start in tests because our vocab fixtures are
# smaller than the original vocabs - let's not assert this
# self.assertEqual(vocab_size, all_size)
new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"]
added_toks = tokenizer.add_tokens(new_toks)
vocab_size_2 = tokenizer.vocab_size
all_size_2 = len(tokenizer)
self.assertNotEqual(vocab_size_2, 0)
self.assertEqual(vocab_size, vocab_size_2)
self.assertEqual(added_toks, len(new_toks))
self.assertEqual(all_size_2, all_size + len(new_toks))
tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False)
self.assertGreaterEqual(len(tokens), 4)
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
self.assertGreater(tokens[-3], tokenizer.vocab_size - 1)
new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
vocab_size_3 = tokenizer.vocab_size
all_size_3 = len(tokenizer)
self.assertNotEqual(vocab_size_3, 0)
self.assertEqual(vocab_size, vocab_size_3)
self.assertEqual(added_toks_2, len(new_toks_2))
self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))
tokens = tokenizer.encode(
">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
)
self.assertGreaterEqual(len(tokens), 6)
self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
self.assertGreater(tokens[0], tokens[1])
self.assertGreater(tokens[-3], tokenizer.vocab_size - 1)
self.assertGreater(tokens[-3], tokens[-4])
self.assertEqual(tokens[0], tokenizer.eos_token_id)
self.assertEqual(tokens[-3], tokenizer.pad_token_id)