def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
    """
    Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to
    it with indices starting from the length of the current vocabulary.

    Args:
        new_tokens (:obj:`List[str]` or :obj:`List[tokenizers.AddedToken]`):
            Token(s) to add in vocabulary. A token is only added if it's not already in the vocabulary (tested by
            checking if the tokenizer assigns the index of the ``unk_token`` to it). Every entry is converted with
            ``str()`` before being inspected.
        special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the tokens should be added as special tokens. Non-special tokens are lower-cased first
            when the tokenizer has a truthy ``do_lower_case`` attribute.

    Returns:
        :obj:`int`: The number of tokens actually added to the vocabulary.

    Examples::

        # Let's see how to increase the vocabulary of a Wav2Vec2 model and tokenizer
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('facebook/wav2vec2-base-960h')
        model = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base-960h')

        num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
        print('We have added', num_added_toks, 'tokens')
        # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
        model.resize_token_embeddings(len(tokenizer))
    """
    # Coercing up front guarantees plain strings, so no per-token type assertion is needed below.
    new_tokens = [str(tok) for tok in new_tokens]
    if not new_tokens:
        return 0

    lower_case = not special_tokens and getattr(self, "do_lower_case", False)

    # Neither value changes while candidates are collected: the added-token maps are only updated after
    # the loop, so both are looked up once instead of once per token.
    unk_token = self.unk_token
    unk_id = self.convert_tokens_to_ids(unk_token)

    tokens_to_add = []
    seen = set()  # constant-time duplicate check; ``tokens_to_add`` keeps the insertion order
    for token in new_tokens:
        if lower_case:
            token = token.lower()
        if token == unk_token or token in seen:
            continue
        # A token that already maps to a real id (i.e. not to the unknown id) is known and is skipped.
        if self.convert_tokens_to_ids(token) != unk_id:
            continue
        seen.add(token)
        tokens_to_add.append(token)
        if self.verbose:
            logger.info(f"Adding {token} to the vocabulary")

    # New ids continue right after the current total length of the tokenizer.
    added_tok_encoder = {tok: idx for idx, tok in enumerate(tokens_to_add, start=len(self))}
    added_tok_decoder = {idx: tok for tok, idx in added_tok_encoder.items()}
    self.added_tokens_encoder.update(added_tok_encoder)
    self.added_tokens_decoder.update(added_tok_decoder)

    # Only newly added tokens longer than one character are registered as additional special / no-split
    # tokens; single-character tokens are not registered here.
    # NOTE(review): presumably single characters are already handled by the per-character tokenization - confirm.
    for token in tokens_to_add:
        if len(token) > 1:
            self._additional_special_tokens.append(AddedToken(token))
            _insert_one_token_to_ordered_list(self.unique_no_split_tokens, token)

    return len(tokens_to_add)
def test_tokenizer_add_token_chars(self):
    """Check the ids produced after adding single-character tokens to the pretrained tokenizer."""
    tok = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")

    # Each step is (tokens to add, text to encode, expected ids); steps run in order on the same
    # tokenizer instance. The last step adds the same tokens a second time and encodes text in which
    # the added characters are not separated by spaces.
    steps = (
        ("x", "C x A", [19, 4, 32, 4, 7]),
        (["a", "b", "c"], "C a A c", [19, 4, 33, 4, 7, 4, 35]),
        (["a", "b", "c"], "CaA c", [19, 33, 7, 4, 35]),
    )
    for to_add, text, expected_ids in steps:
        tok.add_tokens(to_add)
        encoded = tok(text).input_ids
        self.assertEqual(encoded, expected_ids)


def test_tokenizer_add_token_words(self):
    """Check the ids produced after adding multi-character tokens to the pretrained tokenizer."""
    tok = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")

    # Same layout as the single-character test: (tokens to add, text to encode, expected ids), applied
    # in order. The last step adds the same words again and encodes text where one of them is glued
    # between other characters.
    steps = (
        ("xxx", "C xxx A B", [19, 4, 32, 4, 7, 4, 24]),
        (["aaa", "bbb", "ccc"], "C aaa A ccc B B", [19, 4, 33, 4, 7, 4, 35, 4, 24, 4, 24]),
        (["aaa", "bbb", "ccc"], "CaaaA ccc B B", [19, 33, 7, 4, 35, 4, 24, 4, 24]),
    )
    for to_add, text, expected_ids in steps:
        tok.add_tokens(to_add)
        encoded = tok(text).input_ids
        self.assertEqual(encoded, expected_ids)
def test_add_tokens_tokenizer(self):
    """
    Check size bookkeeping and encoding after adding regular and special tokens.

    Overrides the version from test_tokenization_common.
    """
    for tok in self.get_tokenizers(do_lower_case=False):
        with self.subTest(f"{tok.__class__.__name__}"):
            base_vocab = tok.vocab_size
            total_before = len(tok)

            self.assertNotEqual(base_vocab, 0)

            # The base vocab size is intentionally not compared with len(tokenizer): the test fixtures
            # are smaller than the original vocabs, so tokenizers usually start with added tokens.

            # --- regular tokens ---
            extra_tokens = ["aaaaa bbbbbb", "cccccccccdddddddd"]
            n_added = tok.add_tokens(extra_tokens)
            base_vocab_after_add = tok.vocab_size
            total_after_add = len(tok)

            # Adding tokens must grow the total length but leave the base vocabulary size untouched.
            self.assertNotEqual(base_vocab_after_add, 0)
            self.assertEqual(base_vocab, base_vocab_after_add)
            self.assertEqual(n_added, len(extra_tokens))
            self.assertEqual(total_after_add, total_before + len(extra_tokens))

            ids = tok.encode("aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False)

            # Ids above the last base-vocab index belong to added tokens.
            last_base_id = tok.vocab_size - 1
            self.assertGreaterEqual(len(ids), 4)
            self.assertGreater(ids[0], last_base_id)
            self.assertGreater(ids[-3], last_base_id)

            # --- special tokens ---
            special_map = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
            n_special_added = tok.add_special_tokens(special_map)
            base_vocab_after_special = tok.vocab_size
            total_after_special = len(tok)

            self.assertNotEqual(base_vocab_after_special, 0)
            self.assertEqual(base_vocab, base_vocab_after_special)
            self.assertEqual(n_special_added, len(special_map))
            self.assertEqual(total_after_special, total_after_add + len(special_map))

            ids = tok.encode(
                ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
            )

            last_base_id = tok.vocab_size - 1
            self.assertGreaterEqual(len(ids), 6)
            self.assertGreater(ids[0], last_base_id)
            self.assertGreater(ids[0], ids[1])
            self.assertGreater(ids[-3], last_base_id)
            self.assertGreater(ids[-3], ids[-4])
            # The first and third-from-last ids must be the newly registered eos and pad tokens.
            self.assertEqual(ids[0], tok.eos_token_id)
            self.assertEqual(ids[-3], tok.pad_token_id)