[Wav2Vec2] Fix special tokens for Wav2Vec2 tokenizer (#11349)

* fix wav2vec2 tok

* up
This commit is contained in:
Patrick von Platen
2021-04-22 12:23:08 +02:00
committed by GitHub
parent 6f14eab50b
commit 880154d2e1
2 changed files with 28 additions and 1 deletions

View File

@@ -447,6 +447,26 @@ class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
def test_special_characters_in_vocab(self):
sent = "ʈʰ æ æ̃ ˧ kʰ"
vocab_dict = {k: v for v, k in enumerate({phoneme for phoneme in sent.split()})}
vocab_file = os.path.join(self.tmpdirname, "vocab_special.json")
with open(vocab_file, "w") as f:
json.dump(vocab_dict, f)
tokenizer = Wav2Vec2CTCTokenizer(vocab_file)
expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
self.assertEqual(sent, expected_sent)
tokenizer.save_pretrained(os.path.join(self.tmpdirname, "special_tokenizer"))
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(os.path.join(self.tmpdirname, "special_tokenizer"))
expected_sent = tokenizer.decode(tokenizer(sent).input_ids, spaces_between_special_tokens=True)
self.assertEqual(sent, expected_sent)
def test_pretrained_model_lists(self):
# Wav2Vec2Model has no max model length => no
# Wav2Vec2Model has no max model length => no testing
pass