[Wav2Vec2] Fix special tokens for Wav2Vec2 tokenizer (#11349)
* fix wav2vec2 tok * up
This commit is contained in:
committed by
GitHub
parent
6f14eab50b
commit
880154d2e1
@@ -447,6 +447,26 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
|
||||
|
||||
def test_special_characters_in_vocab(self):
|
||||
sent = "ʈʰ æ æ̃ ˧ kʰ"
|
||||
|
||||
vocab_dict = {k: v for v, k in enumerate({phoneme for phoneme in sent.split()})}
|
||||
vocab_file = os.path.join(self.tmpdirname, "vocab_special.json")
|
||||
|
||||
with open(vocab_file, "w") as f:
|
||||
json.dump(vocab_dict, f)
|
||||
|
||||
tokenizer = Wav2Vec2CTCTokenizer(vocab_file)
|
||||
|
||||
expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
|
||||
self.assertEqual(sent, expected_sent)
|
||||
|
||||
tokenizer.save_pretrained(os.path.join(self.tmpdirname, "special_tokenizer"))
|
||||
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(os.path.join(self.tmpdirname, "special_tokenizer"))
|
||||
|
||||
expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
|
||||
self.assertEqual(sent, expected_sent)
|
||||
|
||||
def test_pretrained_model_lists(self):
|
||||
# Wav2Vec2Model has no max model length => no
|
||||
# Wav2Vec2Model has no max model length => no testing
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user