From 15cd68713d8d027e1033906bf39e999a24b5b5dd Mon Sep 17 00:00:00 2001 From: "Minsub Lee (Matt)" Date: Tue, 2 Apr 2024 23:55:11 +0900 Subject: [PATCH] Fix `skip_special_tokens` for `Wav2Vec2CTCTokenizer._decode` (#29311) * Fix skip_special_tokens process for Wav2Vec2CTCTokenizer._decode * Fix skip_special_tokens for Wav2Vec2CTCTokenizer._decode * Exclude pad_token filtering since it is used as CTC-blank token * Add small test for skip_special_tokens * Update decoding test for added new token --- .../models/wav2vec2/tokenization_wav2vec2.py | 9 ++++++--- tests/models/wav2vec2/test_tokenization_wav2vec2.py | 13 +++++++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py index 42b1aa3063..34848a841e 100644 --- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py +++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py @@ -113,7 +113,6 @@ class Wav2Vec2CTCTokenizerOutput(ModelOutput): class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): - """ Constructs a Wav2Vec2CTC tokenizer. @@ -420,7 +419,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer): result = [] for token in filtered_tokens: - if skip_special_tokens and token in self.all_special_ids: + if skip_special_tokens and ( + token in self.all_special_ids or (token != self.pad_token and token in self.all_special_tokens) + ): continue result.append(token) @@ -881,7 +882,9 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer): result = [] for token in filtered_tokens: - if skip_special_tokens and token in self.all_special_ids: + if skip_special_tokens and ( + token in self.all_special_ids or (token != self.pad_token and token in self.all_special_tokens) + ): continue result.append(token) diff --git a/tests/models/wav2vec2/test_tokenization_wav2vec2.py b/tests/models/wav2vec2/test_tokenization_wav2vec2.py index 05109f9736..6c98e0e0c8 100644 --- a/tests/models/wav2vec2/test_tokenization_wav2vec2.py +++ b/tests/models/wav2vec2/test_tokenization_wav2vec2.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for the Wav2Vec2 tokenizer.""" + import inspect import json import os @@ -144,8 +145,10 @@ class Wav2Vec2TokenizerTest(unittest.TestCase): [24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34], ] batch_tokens = tokenizer.batch_decode(sample_ids) + batch_tokens_2 = tokenizer.batch_decode(sample_ids, skip_special_tokens=True) self.assertEqual(batch_tokens, ["HELLO!?!?$$$", "BYE BYE$$$"]) + self.assertEqual(batch_tokens_2, ["HELO!?!?", "BYE BYE"]) def test_call(self): # Tests that all call wrap to encode_plus and batch_encode_plus @@ -452,18 +455,20 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase): def test_tokenizer_decode_added_tokens(self): tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h") - tokenizer.add_tokens(["!", "?"]) + tokenizer.add_tokens(["!", "?", ""]) tokenizer.add_special_tokens({"cls_token": "$$$"}) # fmt: off sample_ids = [ - [11, 5, 15, tokenizer.pad_token_id, 15, 8, 98, 32, 32, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34], - [24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34], + [11, 5, 15, tokenizer.pad_token_id, 15, 8, 98, 32, 32, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34, 35, 35], + [24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34, 35, 35], ] # fmt: on batch_tokens = tokenizer.batch_decode(sample_ids) + batch_tokens_2 = tokenizer.batch_decode(sample_ids, skip_special_tokens=True) - self.assertEqual(batch_tokens, ["HELLO!?!?$$$", "BYE BYE$$$"]) + self.assertEqual(batch_tokens, ["HELLO!?!?$$$", "BYE BYE$$$"]) + self.assertEqual(batch_tokens_2, ["HELO!?!?", "BYE BYE"]) def test_special_characters_in_vocab(self): sent = "ʈʰ æ æ̃ ˧ kʰ"