From 905e5773a3a0f3d6e487a55eaa10457c2d7644a7 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Thu, 10 Nov 2022 19:29:20 +0000 Subject: [PATCH] [processor] Add 'model input names' property (#20117) * [processor] Add 'model input names' property * add test * no f string * add generic property method to mixin * copy to multimodal * copy to vision * tests for all audio * remove ad-hoc tests * style * fix flava test * fix test * fix processor code --- .../models/clip/processing_clip.py | 6 +++++ .../models/flava/processing_flava.py | 6 +++++ .../layoutlmv2/processing_layoutlmv2.py | 4 +++ .../layoutlmv3/processing_layoutlmv3.py | 4 +++ .../models/layoutxlm/processing_layoutxlm.py | 4 +++ .../models/markuplm/processing_markuplm.py | 5 ++++ .../models/owlvit/processing_owlvit.py | 6 +++++ .../models/vilt/processing_vilt.py | 6 +++++ .../processing_vision_text_dual_encoder.py | 5 ++++ .../models/x_clip/processing_x_clip.py | 4 +++ src/transformers/processing_utils.py | 5 ++++ tests/models/clip/test_processor_clip.py | 13 +++++++++ tests/models/flava/test_processor_flava.py | 13 +++++++++ .../layoutlmv2/test_processor_layoutlmv2.py | 27 +++++++++++++++++++ .../layoutlmv3/test_processor_layoutlmv3.py | 27 +++++++++++++++++++ .../layoutxlm/test_processor_layoutxlm.py | 27 +++++++++++++++++++ .../markuplm/test_processor_markuplm.py | 12 +++++++++ tests/models/mctct/test_processor_mctct.py | 12 +++++++++ tests/models/owlvit/test_processor_owlvit.py | 13 +++++++++ .../test_processor_speech_to_text.py | 12 +++++++++ ...test_processor_vision_text_dual_encoder.py | 13 +++++++++ .../wav2vec2/test_processor_wav2vec2.py | 12 +++++++++ .../test_processor_wav2vec2_with_lm.py | 13 +++++++++ .../models/whisper/test_processor_whisper.py | 12 +++++++++ 24 files changed, 261 insertions(+) diff --git a/src/transformers/models/clip/processing_clip.py b/src/transformers/models/clip/processing_clip.py index 56dad3b817..4353b33795 100644 --- 
a/src/transformers/models/clip/processing_clip.py +++ b/src/transformers/models/clip/processing_clip.py @@ -105,3 +105,9 @@ class CLIPProcessor(ProcessorMixin): the docstring of this method for more information. """ return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + feature_extractor_input_names = self.feature_extractor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names)) diff --git a/src/transformers/models/flava/processing_flava.py b/src/transformers/models/flava/processing_flava.py index ca2fa094a8..043befb3d6 100644 --- a/src/transformers/models/flava/processing_flava.py +++ b/src/transformers/models/flava/processing_flava.py @@ -122,3 +122,9 @@ class FlavaProcessor(ProcessorMixin): the docstring of this method for more information. """ return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + feature_extractor_input_names = self.feature_extractor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names)) diff --git a/src/transformers/models/layoutlmv2/processing_layoutlmv2.py b/src/transformers/models/layoutlmv2/processing_layoutlmv2.py index 57f0b78aed..73b361f42a 100644 --- a/src/transformers/models/layoutlmv2/processing_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/processing_layoutlmv2.py @@ -158,3 +158,7 @@ class LayoutLMv2Processor(ProcessorMixin): to the docstring of this method for more information. 
""" return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + return ["input_ids", "bbox", "token_type_ids", "attention_mask", "image"] diff --git a/src/transformers/models/layoutlmv3/processing_layoutlmv3.py b/src/transformers/models/layoutlmv3/processing_layoutlmv3.py index f254cd5b55..763710d285 100644 --- a/src/transformers/models/layoutlmv3/processing_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/processing_layoutlmv3.py @@ -156,3 +156,7 @@ class LayoutLMv3Processor(ProcessorMixin): to the docstring of this method for more information. """ return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + return ["input_ids", "bbox", "attention_mask", "pixel_values"] diff --git a/src/transformers/models/layoutxlm/processing_layoutxlm.py b/src/transformers/models/layoutxlm/processing_layoutxlm.py index b382c03c5f..49fbb1ac3d 100644 --- a/src/transformers/models/layoutxlm/processing_layoutxlm.py +++ b/src/transformers/models/layoutxlm/processing_layoutxlm.py @@ -158,3 +158,7 @@ class LayoutXLMProcessor(ProcessorMixin): to the docstring of this method for more information. """ return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + return ["input_ids", "bbox", "attention_mask", "image"] diff --git a/src/transformers/models/markuplm/processing_markuplm.py b/src/transformers/models/markuplm/processing_markuplm.py index 86ed8dc7ee..d6251586ac 100644 --- a/src/transformers/models/markuplm/processing_markuplm.py +++ b/src/transformers/models/markuplm/processing_markuplm.py @@ -138,3 +138,8 @@ class MarkupLMProcessor(ProcessorMixin): docstring of this method for more information. 
""" return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + return tokenizer_input_names diff --git a/src/transformers/models/owlvit/processing_owlvit.py b/src/transformers/models/owlvit/processing_owlvit.py index 48060f0dcf..707fa47690 100644 --- a/src/transformers/models/owlvit/processing_owlvit.py +++ b/src/transformers/models/owlvit/processing_owlvit.py @@ -159,3 +159,9 @@ class OwlViTProcessor(ProcessorMixin): the docstring of this method for more information. """ return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + feature_extractor_input_names = self.feature_extractor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names)) diff --git a/src/transformers/models/vilt/processing_vilt.py b/src/transformers/models/vilt/processing_vilt.py index 7410666d8c..24b94d0c57 100644 --- a/src/transformers/models/vilt/processing_vilt.py +++ b/src/transformers/models/vilt/processing_vilt.py @@ -106,3 +106,9 @@ class ViltProcessor(ProcessorMixin): the docstring of this method for more information. 
""" return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + feature_extractor_input_names = self.feature_extractor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names)) diff --git a/src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py index 5466df4fc7..94a1b5b913 100644 --- a/src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +++ b/src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py @@ -127,6 +127,13 @@ class VisionTextDualEncoderProcessor(ProcessorMixin): return self.tokenizer.decode(*args, **kwargs) @property + def model_input_names(self): + """Return the de-duplicated tokenizer and image-processor input names, tokenizer names first.""" + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + @property def feature_extractor_class(self): warnings.warn( "`feature_extractor_class` is deprecated and will be removed in v4.27. Use `image_processor_class`" diff --git a/src/transformers/models/x_clip/processing_x_clip.py b/src/transformers/models/x_clip/processing_x_clip.py index 7e694a3e33..6717175c65 100644 --- a/src/transformers/models/x_clip/processing_x_clip.py +++ b/src/transformers/models/x_clip/processing_x_clip.py @@ -107,3 +107,7 @@ class XCLIPProcessor(ProcessorMixin): the docstring of this method for more information. 
""" return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + return ["input_ids", "attention_mask", "position_ids", "pixel_values"] diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 40f0e6a01e..027e669a3f 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -227,6 +227,11 @@ class ProcessorMixin(PushToHubMixin): args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs)) return args + @property + def model_input_names(self): + first_attribute = getattr(self, self.attributes[0]) + return getattr(first_attribute, "model_input_names", None) + ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub) ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format( diff --git a/tests/models/clip/test_processor_clip.py b/tests/models/clip/test_processor_clip.py index 51a0236b90..6cfd5c0261 100644 --- a/tests/models/clip/test_processor_clip.py +++ b/tests/models/clip/test_processor_clip.py @@ -187,3 +187,16 @@ class CLIPProcessorTest(unittest.TestCase): decoded_tok = tokenizer.batch_decode(predicted_ids) self.assertListEqual(decoded_tok, decoded_processor) + + def test_model_input_names(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = CLIPProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input) + + self.assertListEqual(list(inputs.keys()), processor.model_input_names) diff --git a/tests/models/flava/test_processor_flava.py b/tests/models/flava/test_processor_flava.py index 11b8a8add4..94fcb77df5 100644 --- a/tests/models/flava/test_processor_flava.py +++ b/tests/models/flava/test_processor_flava.py @@ -231,3 +231,16 @@ class FlavaProcessorTest(unittest.TestCase): decoded_tok = 
tokenizer.batch_decode(predicted_ids) self.assertListEqual(decoded_tok, decoded_processor) + + def test_model_input_names(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = FlavaProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input) + + self.assertListEqual(list(inputs.keys()), processor.model_input_names) diff --git a/tests/models/layoutlmv2/test_processor_layoutlmv2.py b/tests/models/layoutlmv2/test_processor_layoutlmv2.py index 4f686155ad..c1fdde7d7c 100644 --- a/tests/models/layoutlmv2/test_processor_layoutlmv2.py +++ b/tests/models/layoutlmv2/test_processor_layoutlmv2.py @@ -19,6 +19,8 @@ import tempfile import unittest from typing import List +import numpy as np + from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast from transformers.models.layoutlmv2 import LayoutLMv2Tokenizer, LayoutLMv2TokenizerFast from transformers.models.layoutlmv2.tokenization_layoutlmv2 import VOCAB_FILES_NAMES @@ -86,6 +88,17 @@ class LayoutLMv2ProcessorTest(unittest.TestCase): def tearDown(self): shutil.rmtree(self.tmpdirname) + def prepare_image_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. 
+ """ + + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + return image_inputs + def test_save_load_pretrained_default(self): feature_extractor = self.get_feature_extractor() tokenizers = self.get_tokenizers() @@ -133,6 +146,20 @@ class LayoutLMv2ProcessorTest(unittest.TestCase): self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor) + def test_model_input_names(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = LayoutLMv2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # add extra args + inputs = processor(text=input_str, images=image_input, return_codebook_pixels=False, return_image_mask=False) + + self.assertListEqual(list(inputs.keys()), processor.model_input_names) + @slow def test_overflowing_tokens(self): # In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences). 
diff --git a/tests/models/layoutlmv3/test_processor_layoutlmv3.py b/tests/models/layoutlmv3/test_processor_layoutlmv3.py index a01b0a00cd..6a0062ed3c 100644 --- a/tests/models/layoutlmv3/test_processor_layoutlmv3.py +++ b/tests/models/layoutlmv3/test_processor_layoutlmv3.py @@ -19,6 +19,8 @@ import tempfile import unittest from typing import List +import numpy as np + from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast from transformers.models.layoutlmv3 import LayoutLMv3Tokenizer, LayoutLMv3TokenizerFast from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES @@ -99,6 +101,17 @@ class LayoutLMv3ProcessorTest(unittest.TestCase): def tearDown(self): shutil.rmtree(self.tmpdirname) + def prepare_image_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. + """ + + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + return image_inputs + def test_save_load_pretrained_default(self): feature_extractor = self.get_feature_extractor() tokenizers = self.get_tokenizers() @@ -146,6 +159,20 @@ class LayoutLMv3ProcessorTest(unittest.TestCase): self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor) + def test_model_input_names(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = LayoutLMv3Processor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # add extra args + inputs = processor(text=input_str, images=image_input, return_codebook_pixels=False, return_image_mask=False) + + 
self.assertListEqual(list(inputs.keys()), processor.model_input_names) + # different use cases tests @require_torch diff --git a/tests/models/layoutxlm/test_processor_layoutxlm.py b/tests/models/layoutxlm/test_processor_layoutxlm.py index 2752bd16a8..2843528bae 100644 --- a/tests/models/layoutxlm/test_processor_layoutxlm.py +++ b/tests/models/layoutxlm/test_processor_layoutxlm.py @@ -19,6 +19,8 @@ import tempfile import unittest from typing import List +import numpy as np + from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast from transformers.models.layoutxlm import LayoutXLMTokenizer, LayoutXLMTokenizerFast from transformers.testing_utils import ( @@ -74,6 +76,17 @@ class LayoutXLMProcessorTest(unittest.TestCase): def tearDown(self): shutil.rmtree(self.tmpdirname) + def prepare_image_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. 
+ """ + + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + return image_inputs + def test_save_load_pretrained_default(self): feature_extractor = self.get_feature_extractor() tokenizers = self.get_tokenizers() @@ -126,6 +139,20 @@ class LayoutXLMProcessorTest(unittest.TestCase): self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor) + def test_model_input_names(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = LayoutXLMProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # add extra args + inputs = processor(text=input_str, images=image_input, return_codebook_pixels=False, return_image_mask=False) + + self.assertListEqual(list(inputs.keys()), processor.model_input_names) + @slow def test_overflowing_tokens(self): # In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences). 
diff --git a/tests/models/markuplm/test_processor_markuplm.py b/tests/models/markuplm/test_processor_markuplm.py index 6870a63336..141d7bae18 100644 --- a/tests/models/markuplm/test_processor_markuplm.py +++ b/tests/models/markuplm/test_processor_markuplm.py @@ -133,6 +133,18 @@ class MarkupLMProcessorTest(unittest.TestCase): self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) self.assertIsInstance(processor.feature_extractor, MarkupLMFeatureExtractor) + def test_model_input_names(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = MarkupLMProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + self.assertListEqual( + processor.model_input_names, + tokenizer.model_input_names, + msg="`processor` and `tokenizer` model input names do not match", + ) + # different use cases tests @require_bs4 diff --git a/tests/models/mctct/test_processor_mctct.py b/tests/models/mctct/test_processor_mctct.py index 821e44b48e..306d4b174f 100644 --- a/tests/models/mctct/test_processor_mctct.py +++ b/tests/models/mctct/test_processor_mctct.py @@ -144,3 +144,15 @@ class MCTCTProcessorTest(unittest.TestCase): decoded_tok = tokenizer.batch_decode(predicted_ids) self.assertListEqual(decoded_tok, decoded_processor) + + def test_model_input_names(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = MCTCTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + self.assertListEqual( + processor.model_input_names, + feature_extractor.model_input_names, + msg="`processor` and `feature_extractor` model input names do not match", + ) diff --git a/tests/models/owlvit/test_processor_owlvit.py b/tests/models/owlvit/test_processor_owlvit.py index e37f45b15c..98fd1222e3 100644 --- a/tests/models/owlvit/test_processor_owlvit.py +++ b/tests/models/owlvit/test_processor_owlvit.py @@ -239,3 +239,16 @@ class 
OwlViTProcessorTest(unittest.TestCase): decoded_tok = tokenizer.batch_decode(predicted_ids) self.assertListEqual(decoded_tok, decoded_processor) + + def test_model_input_names(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input) + + self.assertListEqual(list(inputs.keys()), processor.model_input_names) diff --git a/tests/models/speech_to_text/test_processor_speech_to_text.py b/tests/models/speech_to_text/test_processor_speech_to_text.py index d519f005d3..9b8b3ccf66 100644 --- a/tests/models/speech_to_text/test_processor_speech_to_text.py +++ b/tests/models/speech_to_text/test_processor_speech_to_text.py @@ -144,3 +144,15 @@ class Speech2TextProcessorTest(unittest.TestCase): decoded_tok = tokenizer.batch_decode(predicted_ids) self.assertListEqual(decoded_tok, decoded_processor) + + def test_model_input_names(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = Speech2TextProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + self.assertListEqual( + processor.model_input_names, + feature_extractor.model_input_names, + msg="`processor` and `feature_extractor` model input names do not match", + ) diff --git a/tests/models/vision_text_dual_encoder/test_processor_vision_text_dual_encoder.py b/tests/models/vision_text_dual_encoder/test_processor_vision_text_dual_encoder.py index c73a44f2c1..30630256f9 100644 --- a/tests/models/vision_text_dual_encoder/test_processor_vision_text_dual_encoder.py +++ b/tests/models/vision_text_dual_encoder/test_processor_vision_text_dual_encoder.py @@ -168,3 +168,16 @@ class VisionTextDualEncoderProcessorTest(unittest.TestCase): decoded_tok = tokenizer.batch_decode(predicted_ids) 
self.assertListEqual(decoded_tok, decoded_processor) + + def test_model_input_names(self): + feature_extractor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input) + + self.assertListEqual(list(inputs.keys()), processor.model_input_names) diff --git a/tests/models/wav2vec2/test_processor_wav2vec2.py b/tests/models/wav2vec2/test_processor_wav2vec2.py index 5f1c259061..67883618ca 100644 --- a/tests/models/wav2vec2/test_processor_wav2vec2.py +++ b/tests/models/wav2vec2/test_processor_wav2vec2.py @@ -137,3 +137,15 @@ class Wav2Vec2ProcessorTest(unittest.TestCase): decoded_tok = tokenizer.batch_decode(predicted_ids) self.assertListEqual(decoded_tok, decoded_processor) + + def test_model_input_names(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + self.assertListEqual( + processor.model_input_names, + feature_extractor.model_input_names, + msg="`processor` and `feature_extractor` model input names do not match", + ) diff --git a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py index 11a45b6e1a..92e185bdc7 100644 --- a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py +++ b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py @@ -367,6 +367,19 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text) + def test_model_input_names(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + decoder = self.get_decoder() + + processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, 
feature_extractor=feature_extractor, decoder=decoder) + + self.assertListEqual( + processor.model_input_names, + feature_extractor.model_input_names, + msg="`processor` and `feature_extractor` model input names do not match", + ) + @staticmethod def get_from_offsets(offsets, key): retrieved_list = [d[key] for d in offsets] diff --git a/tests/models/whisper/test_processor_whisper.py b/tests/models/whisper/test_processor_whisper.py index 00a5995f00..bcdf1fb9f0 100644 --- a/tests/models/whisper/test_processor_whisper.py +++ b/tests/models/whisper/test_processor_whisper.py @@ -116,3 +116,15 @@ class WhisperProcessorTest(unittest.TestCase): decoded_tok = tokenizer.batch_decode(predicted_ids) self.assertListEqual(decoded_tok, decoded_processor) + + def test_model_input_names(self): + feature_extractor = self.get_feature_extractor() + tokenizer = self.get_tokenizer() + + processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor) + + self.assertListEqual( + processor.model_input_names, + feature_extractor.model_input_names, + msg="`processor` and `feature_extractor` model input names do not match", + )