From 880154d2e1e2bc22c9cc8b829b49971acd6e14f1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 22 Apr 2021 12:23:08 +0200 Subject: [PATCH] [Wav2Vec2] Fix special tokens for Wav2Vec2 tokenizer (#11349) * fix wav2vec2 tok * up --- .../models/wav2vec2/tokenization_wav2vec2.py | 7 ++++++ tests/test_tokenization_wav2vec2.py | 22 ++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py index 841a7b317f..56ec7a92e2 100644 --- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py +++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py @@ -142,6 +142,12 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): self.encoder = json.load(vocab_handle) self.decoder = {v: k for k, v in self.encoder.items()} + # make sure that tokens made of several + # characters are not split at tokenization + for token in self.encoder.keys(): + if len(token) > 1: + self.unique_no_split_tokens.append(token) + @property def word_delimiter_token(self) -> str: """ @@ -366,6 +372,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer): with open(vocab_file, encoding="utf-8") as vocab_handle: self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} @property diff --git a/tests/test_tokenization_wav2vec2.py b/tests/test_tokenization_wav2vec2.py index 002bf4b225..7823de28e0 100644 --- a/tests/test_tokenization_wav2vec2.py +++ b/tests/test_tokenization_wav2vec2.py @@ -447,6 +447,26 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase): self.assertEqual(batch_tokens, ["HELLO!?!?$$$", "BYE BYE$$$"]) + def test_special_characters_in_vocab(self): + sent = "ʈʰ æ æ̃ ˧ kʰ" + + vocab_dict = {k: v for v, k in enumerate({phoneme for phoneme in sent.split()})} + vocab_file = os.path.join(self.tmpdirname, "vocab_special.json") + + with open(vocab_file, "w") as f: + json.dump(vocab_dict, f) + + tokenizer = Wav2Vec2CTCTokenizer(vocab_file) + + expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True) + self.assertEqual(sent, expected_sent) + + tokenizer.save_pretrained(os.path.join(self.tmpdirname, "special_tokenizer")) + tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(os.path.join(self.tmpdirname, "special_tokenizer")) + + expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True) + self.assertEqual(sent, expected_sent) + def test_pretrained_model_lists(self): - # Wav2Vec2Model has no max model length => no + # Wav2Vec2Model has no max model length => no testing pass