[Wav2Vec2] Fix special tokens for Wav2Vec2 tokenizer (#11349)
* fix wav2vec2 tok * up
This commit is contained in:
committed by
GitHub
parent
6f14eab50b
commit
880154d2e1
@@ -142,6 +142,12 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
||||
self.encoder = json.load(vocab_handle)
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
|
||||
# make sure that tokens made of several
|
||||
# characters are not split at tokenization
|
||||
for token in self.encoder.keys():
|
||||
if len(token) > 1:
|
||||
self.unique_no_split_tokens.append(token)
|
||||
|
||||
@property
|
||||
def word_delimiter_token(self) -> str:
|
||||
"""
|
||||
@@ -366,6 +372,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
||||
|
||||
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||
self.encoder = json.load(vocab_handle)
|
||||
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
|
||||
@property
|
||||
|
||||
@@ -447,6 +447,26 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
|
||||
|
||||
def test_special_characters_in_vocab(self):
|
||||
sent = "ʈʰ æ æ̃ ˧ kʰ"
|
||||
|
||||
vocab_dict = {k: v for v, k in enumerate({phoneme for phoneme in sent.split()})}
|
||||
vocab_file = os.path.join(self.tmpdirname, "vocab_special.json")
|
||||
|
||||
with open(vocab_file, "w") as f:
|
||||
json.dump(vocab_dict, f)
|
||||
|
||||
tokenizer = Wav2Vec2CTCTokenizer(vocab_file)
|
||||
|
||||
expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
|
||||
self.assertEqual(sent, expected_sent)
|
||||
|
||||
tokenizer.save_pretrained(os.path.join(self.tmpdirname, "special_tokenizer"))
|
||||
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(os.path.join(self.tmpdirname, "special_tokenizer"))
|
||||
|
||||
expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
|
||||
self.assertEqual(sent, expected_sent)
|
||||
|
||||
def test_pretrained_model_lists(self):
|
||||
# Wav2Vec2Model has no max model length => no
|
||||
# Wav2Vec2Model has no max model length => no testing
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user