Fix skip_special_tokens for Wav2Vec2CTCTokenizer._decode (#29311)
* Fix skip_special_tokens process for Wav2Vec2CTCTokenizer._decode * Fix skip_special_tokens for Wav2Vec2CTCTokenizer._decode * Exclude pad_token filtering since it is used as CTC-blank token * Add small test for skip_special_tokens * Update decoding test for added new token
This commit is contained in:
committed by
GitHub
parent
cb5927ca8f
commit
15cd68713d
@@ -113,7 +113,6 @@ class Wav2Vec2CTCTokenizerOutput(ModelOutput):
|
|||||||
|
|
||||||
|
|
||||||
class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Constructs a Wav2Vec2CTC tokenizer.
|
Constructs a Wav2Vec2CTC tokenizer.
|
||||||
|
|
||||||
@@ -420,7 +419,9 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
result = []
|
result = []
|
||||||
for token in filtered_tokens:
|
for token in filtered_tokens:
|
||||||
if skip_special_tokens and token in self.all_special_ids:
|
if skip_special_tokens and (
|
||||||
|
token in self.all_special_ids or (token != self.pad_token and token in self.all_special_tokens)
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
result.append(token)
|
result.append(token)
|
||||||
|
|
||||||
@@ -881,7 +882,9 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
result = []
|
result = []
|
||||||
for token in filtered_tokens:
|
for token in filtered_tokens:
|
||||||
if skip_special_tokens and token in self.all_special_ids:
|
if skip_special_tokens and (
|
||||||
|
token in self.all_special_ids or (token != self.pad_token and token in self.all_special_tokens)
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
result.append(token)
|
result.append(token)
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tests for the Wav2Vec2 tokenizer."""
|
"""Tests for the Wav2Vec2 tokenizer."""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -144,8 +145,10 @@ class Wav2Vec2TokenizerTest(unittest.TestCase):
|
|||||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34],
|
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34],
|
||||||
]
|
]
|
||||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||||
|
batch_tokens_2 = tokenizer.batch_decode(sample_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
|
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
|
||||||
|
self.assertEqual(batch_tokens_2, ["HELO!?!?", "BYE BYE"])
|
||||||
|
|
||||||
def test_call(self):
|
def test_call(self):
|
||||||
# Tests that all call wrap to encode_plus and batch_encode_plus
|
# Tests that all call wrap to encode_plus and batch_encode_plus
|
||||||
@@ -452,18 +455,20 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_tokenizer_decode_added_tokens(self):
|
def test_tokenizer_decode_added_tokens(self):
|
||||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
tokenizer.add_tokens(["!", "?"])
|
tokenizer.add_tokens(["!", "?", "<new_tokens>"])
|
||||||
tokenizer.add_special_tokens({"cls_token": "$$$"})
|
tokenizer.add_special_tokens({"cls_token": "$$$"})
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
sample_ids = [
|
sample_ids = [
|
||||||
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98, 32, 32, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34],
|
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98, 32, 32, 33, tokenizer.word_delimiter_token_id, 32, 32, 33, 34, 34, 35, 35],
|
||||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34],
|
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34, 35, 35],
|
||||||
]
|
]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||||
|
batch_tokens_2 = tokenizer.batch_decode(sample_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
|
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?<new_tokens>$$$", "BYE BYE<unk><new_tokens>$$$"])
|
||||||
|
self.assertEqual(batch_tokens_2, ["HELO!?!?<new_tokens>", "BYE BYE<new_tokens>"])
|
||||||
|
|
||||||
def test_special_characters_in_vocab(self):
|
def test_special_characters_in_vocab(self):
|
||||||
sent = "ʈʰ æ æ̃ ˧ kʰ"
|
sent = "ʈʰ æ æ̃ ˧ kʰ"
|
||||||
|
|||||||
Reference in New Issue
Block a user