Uniformize model processors (#31368)
* add initial design for uniform processors + align model * add uniform processors for altclip + chinese_clip * add uniform processors for blip + blip2 * fix mutable default 👀 * add configuration test * handle structured kwargs w defaults + add test * protect torch-specific test * fix style * fix * rebase * update processor to generic kwargs + test * fix style * add sensible kwargs merge * update test * fix assertEqual * move kwargs merging to processing common * rework kwargs for type hinting * just get Unpack from extensions * run-slow[align] * handle kwargs passed as nested dict * add from_pretrained test for nested kwargs handling * [run-slow]align * update documentation + imports * update audio inputs * protect audio types, silly * try removing imports * make things simpler * simplerer * move out kwargs test to common mixin * [run-slow]align * skip tests for old processors * [run-slow]align, clip * !$#@!! protect imports, darn it * [run-slow]align, clip * [run-slow]align, clip * update common processor testing * add altclip * add chinese_clip * add pad_size * [run-slow]align, clip, chinese_clip, altclip * remove duplicated tests * fix * add blip, blip2, bridgetower Added tests for bridgetower which override common. Also modified common tests to force center cropping if existing * fix * update doc * improve documentation for default values * add model_max_length testing This parameter depends on tokenizers received. * Raise if kwargs are specified in two places * fix * removed copied from * match defaults * force padding * fix tokenizer test * clean defaults * move tests to common * add missing import * fix * adapt bridgetower tests to shortest edge * uniformize donut processor + tests * add wav2vec2 * extend common testing to audio processors * add testing + bert version * propagate common kwargs to different modalities * BC order of arguments * check py version * revert kwargs merging * add draft overlap test * update * fix blip2 and wav2vec due to updates * fix copies * ensure overlapping kwargs do not disappear * replace .pop by .get to handle duplicated kwargs * fix copies * fix missing import * add clearly wav2vec2_bert to uniformized models * fix copies * increase number of features * fix style * [run-slow] blip, blip2, bridgetower, donut, wav2vec2, wav2vec2_bert * [run-slow] blip, blip_2, bridgetower, donut, wav2vec2, wav2vec2_bert * fix concatenation * [run-slow] blip, blip_2, bridgetower, donut, wav2vec2, wav2vec2_bert * Update tests/test_processing_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * 🧹 * address comments * clean up + tests * [run-slow] instructblip, blip, blip_2, bridgetower, donut, wav2vec2, wav2vec2_bert --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -18,14 +18,19 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
from transformers.utils import FEATURE_EXTRACTOR_NAME
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
from .test_feature_extraction_wav2vec2 import floats_list
|
||||
|
||||
|
||||
class Wav2Vec2ProcessorTest(unittest.TestCase):
|
||||
class Wav2Vec2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = Wav2Vec2Processor
|
||||
|
||||
def setUp(self):
|
||||
vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
|
||||
vocab_tokens = dict(zip(vocab, range(len(vocab))))
|
||||
@@ -53,6 +58,9 @@ class Wav2Vec2ProcessorTest(unittest.TestCase):
|
||||
with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
|
||||
fp.write(json.dumps(feature_extractor_map) + "\n")
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
tokenizer.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, **kwargs_init):
|
||||
kwargs = self.add_kwargs_tokens_map.copy()
|
||||
kwargs.update(kwargs_init)
|
||||
@@ -117,7 +125,6 @@ class Wav2Vec2ProcessorTest(unittest.TestCase):
|
||||
processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
input_str = "This is a test string"
|
||||
|
||||
encoded_processor = processor(text=input_str)
|
||||
|
||||
encoded_tok = tokenizer(input_str)
|
||||
@@ -125,6 +132,22 @@ class Wav2Vec2ProcessorTest(unittest.TestCase):
|
||||
for key in encoded_tok.keys():
|
||||
self.assertListEqual(encoded_tok[key], encoded_processor[key])
|
||||
|
||||
def test_padding_argument_not_ignored(self):
|
||||
# padding, or any other overlap arg between audio extractor and tokenizer
|
||||
# should be passed to both text and audio and not ignored
|
||||
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
batch_duration_in_seconds = [1, 3, 2, 6]
|
||||
input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
|
||||
|
||||
# padding = True should not raise an error and will if the audio processor popped its value to None
|
||||
_ = processor(
|
||||
input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
|
||||
)
|
||||
|
||||
def test_tokenizer_decode(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
Reference in New Issue
Block a user