[Wav2Vec2] Fix special tokens for Wav2Vec2 tokenizer (#11349)
* fix wav2vec2 tok
* up
This commit is contained in:
committed by
GitHub
parent
6f14eab50b
commit
880154d2e1
@@ -142,6 +142,12 @@ class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
|
|||||||
self.encoder = json.load(vocab_handle)
|
self.encoder = json.load(vocab_handle)
|
||||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||||
|
|
||||||
|
# make sure that tokens made of several
|
||||||
|
# characters are not split at tokenization
|
||||||
|
for token in self.encoder.keys():
|
||||||
|
if len(token) > 1:
|
||||||
|
self.unique_no_split_tokens.append(token)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def word_delimiter_token(self) -> str:
|
def word_delimiter_token(self) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -366,6 +372,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||||
self.encoder = json.load(vocab_handle)
|
self.encoder = json.load(vocab_handle)
|
||||||
|
|
||||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -447,6 +447,26 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
|
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
|
||||||
|
|
||||||
|
def test_special_characters_in_vocab(self):
|
||||||
|
sent = "ʈʰ æ æ̃ ˧ kʰ"
|
||||||
|
|
||||||
|
vocab_dict = {k: v for v, k in enumerate({phoneme for phoneme in sent.split()})}
|
||||||
|
vocab_file = os.path.join(self.tmpdirname, "vocab_special.json")
|
||||||
|
|
||||||
|
with open(vocab_file, "w") as f:
|
||||||
|
json.dump(vocab_dict, f)
|
||||||
|
|
||||||
|
tokenizer = Wav2Vec2CTCTokenizer(vocab_file)
|
||||||
|
|
||||||
|
expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
|
||||||
|
self.assertEqual(sent, expected_sent)
|
||||||
|
|
||||||
|
tokenizer.save_pretrained(os.path.join(self.tmpdirname, "special_tokenizer"))
|
||||||
|
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(os.path.join(self.tmpdirname, "special_tokenizer"))
|
||||||
|
|
||||||
|
expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
|
||||||
|
self.assertEqual(sent, expected_sent)
|
||||||
|
|
||||||
def test_pretrained_model_lists(self):
|
def test_pretrained_model_lists(self):
|
||||||
# Wav2Vec2Model has no max model length => no testing
|
# Wav2Vec2Model has no max model length => no testing
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user