[PretrainedFeatureExtractor] + Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2Tokenizer (#10324)
* push to show * small improvement * small improvement * Update src/transformers/feature_extraction_utils.py * Update src/transformers/feature_extraction_utils.py * implement base * add common tests * make all tests pass for wav2vec2 * make padding work & add more tests * finalize feature extractor utils * add call method to feature extraction * finalize feature processor * finish tokenizer * finish general processor design * finish tests * typo * remove bogus file * finish docstring * add docs * finish docs * small fix * correct docs * save intermediate * load changes * apply changes * apply changes to doc * change tests * apply surajs recommend * final changes * Apply suggestions from code review * fix typo * fix import * correct docstring
This commit is contained in:
committed by
GitHub
parent
9dc7825744
commit
cb38ffcc5e
@@ -23,11 +23,17 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.models.wav2vec2 import Wav2Vec2Config, Wav2Vec2Tokenizer
|
||||
from transformers import (
|
||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
Wav2Vec2Config,
|
||||
Wav2Vec2CTCTokenizer,
|
||||
Wav2Vec2Tokenizer,
|
||||
)
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
from .test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
@@ -345,3 +351,101 @@ class Wav2Vec2TokenizerTest(unittest.TestCase):
|
||||
# only "layer" feature extraction norm should make use of
|
||||
# attention_mask
|
||||
self.assertEqual(tokenizer.return_attention_mask, config.feat_extract_norm == "layer")
|
||||
|
||||
|
||||
class Wav2Vec2CTCTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer_class = Wav2Vec2CTCTokenizer
|
||||
test_rust_tokenizer = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
|
||||
self.special_tokens_map = {"pad_token": "<pad>", "unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
|
||||
with open(self.vocab_file, "w", encoding="utf-8") as fp:
|
||||
fp.write(json.dumps(vocab_tokens) + "\n")
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return Wav2Vec2CTCTokenizer.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
def test_tokenizer_decode(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
sample_ids = [
|
||||
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77],
|
||||
]
|
||||
tokens = tokenizer.decode(sample_ids[0])
|
||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||
self.assertEqual(tokens, batch_tokens[0])
|
||||
self.assertEqual(batch_tokens, ["HELLO<unk>", "BYE BYE<unk>"])
|
||||
|
||||
def test_tokenizer_decode_special(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
|
||||
sample_ids = [
|
||||
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77],
|
||||
]
|
||||
sample_ids_2 = [
|
||||
[11, 5, 5, 5, 5, 5, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98],
|
||||
[
|
||||
24,
|
||||
22,
|
||||
5,
|
||||
tokenizer.pad_token_id,
|
||||
tokenizer.pad_token_id,
|
||||
tokenizer.pad_token_id,
|
||||
tokenizer.word_delimiter_token_id,
|
||||
24,
|
||||
22,
|
||||
5,
|
||||
77,
|
||||
tokenizer.word_delimiter_token_id,
|
||||
],
|
||||
]
|
||||
|
||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||
batch_tokens_2 = tokenizer.batch_decode(sample_ids_2)
|
||||
self.assertEqual(batch_tokens, batch_tokens_2)
|
||||
self.assertEqual(batch_tokens, ["HELLO<unk>", "BYE BYE<unk>"])
|
||||
|
||||
def test_tokenizer_decode_added_tokens(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
tokenizer.add_tokens(["!", "?"])
|
||||
tokenizer.add_special_tokens({"cls_token": "$$$"})
|
||||
|
||||
sample_ids = [
|
||||
[
|
||||
11,
|
||||
5,
|
||||
15,
|
||||
tokenizer.pad_token_id,
|
||||
15,
|
||||
8,
|
||||
98,
|
||||
32,
|
||||
32,
|
||||
33,
|
||||
tokenizer.word_delimiter_token_id,
|
||||
32,
|
||||
32,
|
||||
33,
|
||||
34,
|
||||
34,
|
||||
],
|
||||
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34],
|
||||
]
|
||||
batch_tokens = tokenizer.batch_decode(sample_ids)
|
||||
|
||||
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
|
||||
|
||||
def test_pretrained_model_lists(self):
|
||||
# Wav2Vec2Model has no max model length => no
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user