[processor] Add 'model input names' property (#20117)
* [processor] Add 'model input names' property * add test * no f string * add generic property method to mixin * copy to multimodal * copy to vision * tests for all audio * remove ad-hoc tests * style * fix flava test * fix test * fix processor code
This commit is contained in:
@@ -105,3 +105,9 @@ class CLIPProcessor(ProcessorMixin):
|
|||||||
the docstring of this method for more information.
|
the docstring of this method for more information.
|
||||||
"""
|
"""
|
||||||
return self.tokenizer.decode(*args, **kwargs)
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_input_names(self):
|
||||||
|
tokenizer_input_names = self.tokenizer.model_input_names
|
||||||
|
feature_extractor_input_names = self.feature_extractor.model_input_names
|
||||||
|
return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
|
||||||
|
|||||||
@@ -122,3 +122,9 @@ class FlavaProcessor(ProcessorMixin):
|
|||||||
the docstring of this method for more information.
|
the docstring of this method for more information.
|
||||||
"""
|
"""
|
||||||
return self.tokenizer.decode(*args, **kwargs)
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_input_names(self):
|
||||||
|
tokenizer_input_names = self.tokenizer.model_input_names
|
||||||
|
feature_extractor_input_names = self.feature_extractor.model_input_names
|
||||||
|
return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
|
||||||
|
|||||||
@@ -158,3 +158,7 @@ class LayoutLMv2Processor(ProcessorMixin):
|
|||||||
to the docstring of this method for more information.
|
to the docstring of this method for more information.
|
||||||
"""
|
"""
|
||||||
return self.tokenizer.decode(*args, **kwargs)
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_input_names(self):
|
||||||
|
return ["input_ids", "bbox", "token_type_ids", "attention_mask", "image"]
|
||||||
|
|||||||
@@ -156,3 +156,7 @@ class LayoutLMv3Processor(ProcessorMixin):
|
|||||||
to the docstring of this method for more information.
|
to the docstring of this method for more information.
|
||||||
"""
|
"""
|
||||||
return self.tokenizer.decode(*args, **kwargs)
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_input_names(self):
|
||||||
|
return ["input_ids", "bbox", "attention_mask", "pixel_values"]
|
||||||
|
|||||||
@@ -158,3 +158,7 @@ class LayoutXLMProcessor(ProcessorMixin):
|
|||||||
to the docstring of this method for more information.
|
to the docstring of this method for more information.
|
||||||
"""
|
"""
|
||||||
return self.tokenizer.decode(*args, **kwargs)
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_input_names(self):
|
||||||
|
return ["input_ids", "bbox", "attention_mask", "image"]
|
||||||
|
|||||||
@@ -138,3 +138,8 @@ class MarkupLMProcessor(ProcessorMixin):
|
|||||||
docstring of this method for more information.
|
docstring of this method for more information.
|
||||||
"""
|
"""
|
||||||
return self.tokenizer.decode(*args, **kwargs)
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_input_names(self):
|
||||||
|
tokenizer_input_names = self.tokenizer.model_input_names
|
||||||
|
return tokenizer_input_names
|
||||||
|
|||||||
@@ -159,3 +159,9 @@ class OwlViTProcessor(ProcessorMixin):
|
|||||||
the docstring of this method for more information.
|
the docstring of this method for more information.
|
||||||
"""
|
"""
|
||||||
return self.tokenizer.decode(*args, **kwargs)
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_input_names(self):
|
||||||
|
tokenizer_input_names = self.tokenizer.model_input_names
|
||||||
|
feature_extractor_input_names = self.feature_extractor.model_input_names
|
||||||
|
return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
|
||||||
|
|||||||
@@ -106,3 +106,9 @@ class ViltProcessor(ProcessorMixin):
|
|||||||
the docstring of this method for more information.
|
the docstring of this method for more information.
|
||||||
"""
|
"""
|
||||||
return self.tokenizer.decode(*args, **kwargs)
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_input_names(self):
|
||||||
|
tokenizer_input_names = self.tokenizer.model_input_names
|
||||||
|
feature_extractor_input_names = self.feature_extractor.model_input_names
|
||||||
|
return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
|
||||||
|
|||||||
@@ -127,6 +127,11 @@ class VisionTextDualEncoderProcessor(ProcessorMixin):
|
|||||||
return self.tokenizer.decode(*args, **kwargs)
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
def model_input_names(self):
|
||||||
|
tokenizer_input_names = self.tokenizer.model_input_names
|
||||||
|
image_processor_input_names = self.image_processor.model_input_names
|
||||||
|
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||||
|
|
||||||
def feature_extractor_class(self):
|
def feature_extractor_class(self):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"`feature_extractor_class` is deprecated and will be removed in v4.27. Use `image_processor_class`"
|
"`feature_extractor_class` is deprecated and will be removed in v4.27. Use `image_processor_class`"
|
||||||
|
|||||||
@@ -107,3 +107,7 @@ class XCLIPProcessor(ProcessorMixin):
|
|||||||
the docstring of this method for more information.
|
the docstring of this method for more information.
|
||||||
"""
|
"""
|
||||||
return self.tokenizer.decode(*args, **kwargs)
|
return self.tokenizer.decode(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_input_names(self):
|
||||||
|
return ["input_ids", "attention_mask", "position_ids", "pixel_values"]
|
||||||
|
|||||||
@@ -227,6 +227,11 @@ class ProcessorMixin(PushToHubMixin):
|
|||||||
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
|
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_input_names(self):
|
||||||
|
first_attribute = getattr(self, self.attributes[0])
|
||||||
|
return getattr(first_attribute, "model_input_names", None)
|
||||||
|
|
||||||
|
|
||||||
ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
|
ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
|
||||||
ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
|
ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
|
||||||
|
|||||||
@@ -187,3 +187,16 @@ class CLIPProcessorTest(unittest.TestCase):
|
|||||||
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
||||||
|
|
||||||
self.assertListEqual(decoded_tok, decoded_processor)
|
self.assertListEqual(decoded_tok, decoded_processor)
|
||||||
|
|
||||||
|
def test_model_input_names(self):
|
||||||
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
processor = CLIPProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||||
|
|
||||||
|
input_str = "lower newer"
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
|
||||||
|
inputs = processor(text=input_str, images=image_input)
|
||||||
|
|
||||||
|
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||||
|
|||||||
@@ -231,3 +231,16 @@ class FlavaProcessorTest(unittest.TestCase):
|
|||||||
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
||||||
|
|
||||||
self.assertListEqual(decoded_tok, decoded_processor)
|
self.assertListEqual(decoded_tok, decoded_processor)
|
||||||
|
|
||||||
|
def test_model_input_names(self):
|
||||||
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
processor = FlavaProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||||
|
|
||||||
|
input_str = "lower newer"
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
|
||||||
|
inputs = processor(text=input_str, images=image_input)
|
||||||
|
|
||||||
|
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
|
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
|
||||||
from transformers.models.layoutlmv2 import LayoutLMv2Tokenizer, LayoutLMv2TokenizerFast
|
from transformers.models.layoutlmv2 import LayoutLMv2Tokenizer, LayoutLMv2TokenizerFast
|
||||||
from transformers.models.layoutlmv2.tokenization_layoutlmv2 import VOCAB_FILES_NAMES
|
from transformers.models.layoutlmv2.tokenization_layoutlmv2 import VOCAB_FILES_NAMES
|
||||||
@@ -86,6 +88,17 @@ class LayoutLMv2ProcessorTest(unittest.TestCase):
|
|||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
shutil.rmtree(self.tmpdirname)
|
shutil.rmtree(self.tmpdirname)
|
||||||
|
|
||||||
|
def prepare_image_inputs(self):
|
||||||
|
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
||||||
|
or a list of PyTorch tensors if one specifies torchify=True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
|
||||||
|
|
||||||
|
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
|
||||||
|
|
||||||
|
return image_inputs
|
||||||
|
|
||||||
def test_save_load_pretrained_default(self):
|
def test_save_load_pretrained_default(self):
|
||||||
feature_extractor = self.get_feature_extractor()
|
feature_extractor = self.get_feature_extractor()
|
||||||
tokenizers = self.get_tokenizers()
|
tokenizers = self.get_tokenizers()
|
||||||
@@ -133,6 +146,20 @@ class LayoutLMv2ProcessorTest(unittest.TestCase):
|
|||||||
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
|
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
|
||||||
self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)
|
self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)
|
||||||
|
|
||||||
|
def test_model_input_names(self):
|
||||||
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
processor = LayoutLMv2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||||
|
|
||||||
|
input_str = "lower newer"
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
|
||||||
|
# add extra args
|
||||||
|
inputs = processor(text=input_str, images=image_input, return_codebook_pixels=False, return_image_mask=False)
|
||||||
|
|
||||||
|
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_overflowing_tokens(self):
|
def test_overflowing_tokens(self):
|
||||||
# In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences).
|
# In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences).
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
|
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
|
||||||
from transformers.models.layoutlmv3 import LayoutLMv3Tokenizer, LayoutLMv3TokenizerFast
|
from transformers.models.layoutlmv3 import LayoutLMv3Tokenizer, LayoutLMv3TokenizerFast
|
||||||
from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES
|
from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES
|
||||||
@@ -99,6 +101,17 @@ class LayoutLMv3ProcessorTest(unittest.TestCase):
|
|||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
shutil.rmtree(self.tmpdirname)
|
shutil.rmtree(self.tmpdirname)
|
||||||
|
|
||||||
|
def prepare_image_inputs(self):
|
||||||
|
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
||||||
|
or a list of PyTorch tensors if one specifies torchify=True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
|
||||||
|
|
||||||
|
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
|
||||||
|
|
||||||
|
return image_inputs
|
||||||
|
|
||||||
def test_save_load_pretrained_default(self):
|
def test_save_load_pretrained_default(self):
|
||||||
feature_extractor = self.get_feature_extractor()
|
feature_extractor = self.get_feature_extractor()
|
||||||
tokenizers = self.get_tokenizers()
|
tokenizers = self.get_tokenizers()
|
||||||
@@ -146,6 +159,20 @@ class LayoutLMv3ProcessorTest(unittest.TestCase):
|
|||||||
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
|
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
|
||||||
self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor)
|
self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor)
|
||||||
|
|
||||||
|
def test_model_input_names(self):
|
||||||
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
processor = LayoutLMv3Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||||
|
|
||||||
|
input_str = "lower newer"
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
|
||||||
|
# add extra args
|
||||||
|
inputs = processor(text=input_str, images=image_input, return_codebook_pixels=False, return_image_mask=False)
|
||||||
|
|
||||||
|
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||||
|
|
||||||
|
|
||||||
# different use cases tests
|
# different use cases tests
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
|
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
|
||||||
from transformers.models.layoutxlm import LayoutXLMTokenizer, LayoutXLMTokenizerFast
|
from transformers.models.layoutxlm import LayoutXLMTokenizer, LayoutXLMTokenizerFast
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
@@ -74,6 +76,17 @@ class LayoutXLMProcessorTest(unittest.TestCase):
|
|||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
shutil.rmtree(self.tmpdirname)
|
shutil.rmtree(self.tmpdirname)
|
||||||
|
|
||||||
|
def prepare_image_inputs(self):
|
||||||
|
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
||||||
|
or a list of PyTorch tensors if one specifies torchify=True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
|
||||||
|
|
||||||
|
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
|
||||||
|
|
||||||
|
return image_inputs
|
||||||
|
|
||||||
def test_save_load_pretrained_default(self):
|
def test_save_load_pretrained_default(self):
|
||||||
feature_extractor = self.get_feature_extractor()
|
feature_extractor = self.get_feature_extractor()
|
||||||
tokenizers = self.get_tokenizers()
|
tokenizers = self.get_tokenizers()
|
||||||
@@ -126,6 +139,20 @@ class LayoutXLMProcessorTest(unittest.TestCase):
|
|||||||
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
|
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
|
||||||
self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)
|
self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)
|
||||||
|
|
||||||
|
def test_model_input_names(self):
|
||||||
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
processor = LayoutXLMProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||||
|
|
||||||
|
input_str = "lower newer"
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
|
||||||
|
# add extra args
|
||||||
|
inputs = processor(text=input_str, images=image_input, return_codebook_pixels=False, return_image_mask=False)
|
||||||
|
|
||||||
|
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_overflowing_tokens(self):
|
def test_overflowing_tokens(self):
|
||||||
# In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences).
|
# In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences).
|
||||||
|
|||||||
@@ -133,6 +133,18 @@ class MarkupLMProcessorTest(unittest.TestCase):
|
|||||||
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
|
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
|
||||||
self.assertIsInstance(processor.feature_extractor, MarkupLMFeatureExtractor)
|
self.assertIsInstance(processor.feature_extractor, MarkupLMFeatureExtractor)
|
||||||
|
|
||||||
|
def test_model_input_names(self):
|
||||||
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
processor = MarkupLMProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
processor.model_input_names,
|
||||||
|
tokenizer.model_input_names,
|
||||||
|
msg="`processor` and `tokenizer` model input names do not match",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# different use cases tests
|
# different use cases tests
|
||||||
@require_bs4
|
@require_bs4
|
||||||
|
|||||||
@@ -144,3 +144,15 @@ class MCTCTProcessorTest(unittest.TestCase):
|
|||||||
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
||||||
|
|
||||||
self.assertListEqual(decoded_tok, decoded_processor)
|
self.assertListEqual(decoded_tok, decoded_processor)
|
||||||
|
|
||||||
|
def test_model_input_names(self):
|
||||||
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
processor = MCTCTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
processor.model_input_names,
|
||||||
|
feature_extractor.model_input_names,
|
||||||
|
msg="`processor` and `feature_extractor` model input names do not match",
|
||||||
|
)
|
||||||
|
|||||||
@@ -239,3 +239,16 @@ class OwlViTProcessorTest(unittest.TestCase):
|
|||||||
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
||||||
|
|
||||||
self.assertListEqual(decoded_tok, decoded_processor)
|
self.assertListEqual(decoded_tok, decoded_processor)
|
||||||
|
|
||||||
|
def test_model_input_names(self):
|
||||||
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
processor = OwlViTProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||||
|
|
||||||
|
input_str = "lower newer"
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
|
||||||
|
inputs = processor(text=input_str, images=image_input)
|
||||||
|
|
||||||
|
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||||
|
|||||||
@@ -144,3 +144,15 @@ class Speech2TextProcessorTest(unittest.TestCase):
|
|||||||
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
||||||
|
|
||||||
self.assertListEqual(decoded_tok, decoded_processor)
|
self.assertListEqual(decoded_tok, decoded_processor)
|
||||||
|
|
||||||
|
def test_model_input_names(self):
|
||||||
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
processor = Speech2TextProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
processor.model_input_names,
|
||||||
|
feature_extractor.model_input_names,
|
||||||
|
msg="`processor` and `feature_extractor` model input names do not match",
|
||||||
|
)
|
||||||
|
|||||||
@@ -168,3 +168,16 @@ class VisionTextDualEncoderProcessorTest(unittest.TestCase):
|
|||||||
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
||||||
|
|
||||||
self.assertListEqual(decoded_tok, decoded_processor)
|
self.assertListEqual(decoded_tok, decoded_processor)
|
||||||
|
|
||||||
|
def test_model_input_names(self):
|
||||||
|
feature_extractor = self.get_image_processor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||||
|
|
||||||
|
input_str = "lower newer"
|
||||||
|
image_input = self.prepare_image_inputs()
|
||||||
|
|
||||||
|
inputs = processor(text=input_str, images=image_input)
|
||||||
|
|
||||||
|
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||||
|
|||||||
@@ -137,3 +137,15 @@ class Wav2Vec2ProcessorTest(unittest.TestCase):
|
|||||||
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
||||||
|
|
||||||
self.assertListEqual(decoded_tok, decoded_processor)
|
self.assertListEqual(decoded_tok, decoded_processor)
|
||||||
|
|
||||||
|
def test_model_input_names(self):
|
||||||
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
processor.model_input_names,
|
||||||
|
feature_extractor.model_input_names,
|
||||||
|
msg="`processor` and `feature_extractor` model input names do not match",
|
||||||
|
)
|
||||||
|
|||||||
@@ -367,6 +367,19 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text)
|
self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text)
|
||||||
|
|
||||||
|
def test_model_input_names(self):
|
||||||
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
decoder = self.get_decoder()
|
||||||
|
|
||||||
|
processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
processor.model_input_names,
|
||||||
|
feature_extractor.model_input_names,
|
||||||
|
msg="`processor` and `feature_extractor` model input names do not match",
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_from_offsets(offsets, key):
|
def get_from_offsets(offsets, key):
|
||||||
retrieved_list = [d[key] for d in offsets]
|
retrieved_list = [d[key] for d in offsets]
|
||||||
|
|||||||
@@ -116,3 +116,15 @@ class WhisperProcessorTest(unittest.TestCase):
|
|||||||
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
decoded_tok = tokenizer.batch_decode(predicted_ids)
|
||||||
|
|
||||||
self.assertListEqual(decoded_tok, decoded_processor)
|
self.assertListEqual(decoded_tok, decoded_processor)
|
||||||
|
|
||||||
|
def test_model_input_names(self):
|
||||||
|
feature_extractor = self.get_feature_extractor()
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
|
||||||
|
processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
processor.model_input_names,
|
||||||
|
feature_extractor.model_input_names,
|
||||||
|
msg="`processor` and `feature_extractor` model input names do not match",
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user