[processor] Add 'model input names' property (#20117)

* [processor] Add 'model input names' property

* add test

* no f string

* add generic property method to mixin

* copy to multimodal

* copy to vision

* tests for all audio

* remove ad-hoc tests

* style

* fix flava test

* fix test

* fix processor code
This commit is contained in:
Sanchit Gandhi
2022-11-10 19:29:20 +00:00
committed by GitHub
parent 68187c4642
commit 905e5773a3
24 changed files with 261 additions and 0 deletions

View File

@@ -105,3 +105,9 @@ class CLIPProcessor(ProcessorMixin):
the docstring of this method for more information. the docstring of this method for more information.
""" """
return self.tokenizer.decode(*args, **kwargs) return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
    """
    Names of the inputs this processor exposes: the tokenizer's `model_input_names` followed by the feature
    extractor's `model_input_names`, with repeated names dropped (the first occurrence is kept).
    """
    combined_names = self.tokenizer.model_input_names + self.feature_extractor.model_input_names
    # Dict keys are unique and keep insertion order, so this de-duplicates without reordering.
    return list({name: None for name in combined_names})

View File

@@ -122,3 +122,9 @@ class FlavaProcessor(ProcessorMixin):
the docstring of this method for more information. the docstring of this method for more information.
""" """
return self.tokenizer.decode(*args, **kwargs) return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
    """
    Names of the inputs this processor exposes: the tokenizer's `model_input_names` followed by the feature
    extractor's `model_input_names`, with repeated names dropped (the first occurrence is kept).
    """
    text_names = self.tokenizer.model_input_names
    image_names = self.feature_extractor.model_input_names
    # `dict.fromkeys` keeps the first occurrence of every name, in insertion order.
    ordered_unique = dict.fromkeys(text_names + image_names)
    return [*ordered_unique]

View File

@@ -158,3 +158,7 @@ class LayoutLMv2Processor(ProcessorMixin):
to the docstring of this method for more information. to the docstring of this method for more information.
""" """
return self.tokenizer.decode(*args, **kwargs) return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
    """Return the fixed list of input names for this processor (a new list on every access)."""
    input_names = ["input_ids", "bbox", "token_type_ids", "attention_mask", "image"]
    return input_names

View File

@@ -156,3 +156,7 @@ class LayoutLMv3Processor(ProcessorMixin):
to the docstring of this method for more information. to the docstring of this method for more information.
""" """
return self.tokenizer.decode(*args, **kwargs) return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
    """Return the fixed list of input names for this processor (a new list on every access)."""
    input_names = ["input_ids", "bbox", "attention_mask", "pixel_values"]
    return input_names

View File

@@ -158,3 +158,7 @@ class LayoutXLMProcessor(ProcessorMixin):
to the docstring of this method for more information. to the docstring of this method for more information.
""" """
return self.tokenizer.decode(*args, **kwargs) return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
    """Return the fixed list of input names for this processor (a new list on every access)."""
    input_names = ["input_ids", "bbox", "attention_mask", "image"]
    return input_names

View File

@@ -138,3 +138,8 @@ class MarkupLMProcessor(ProcessorMixin):
docstring of this method for more information. docstring of this method for more information.
""" """
return self.tokenizer.decode(*args, **kwargs) return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
    """Return the tokenizer's `model_input_names` unchanged (the same object, not a copy)."""
    return self.tokenizer.model_input_names

View File

@@ -159,3 +159,9 @@ class OwlViTProcessor(ProcessorMixin):
the docstring of this method for more information. the docstring of this method for more information.
""" """
return self.tokenizer.decode(*args, **kwargs) return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
    """
    Names of the inputs this processor exposes: the tokenizer's `model_input_names` followed by the feature
    extractor's `model_input_names`, with repeated names dropped (the first occurrence is kept).
    """
    combined_names = self.tokenizer.model_input_names + self.feature_extractor.model_input_names
    # Dict keys are unique and keep insertion order, so this de-duplicates without reordering.
    return list({name: None for name in combined_names})

View File

@@ -106,3 +106,9 @@ class ViltProcessor(ProcessorMixin):
the docstring of this method for more information. the docstring of this method for more information.
""" """
return self.tokenizer.decode(*args, **kwargs) return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
    """
    Names of the inputs this processor exposes: the tokenizer's `model_input_names` followed by the feature
    extractor's `model_input_names`, with repeated names dropped (the first occurrence is kept).
    """
    text_names = self.tokenizer.model_input_names
    image_names = self.feature_extractor.model_input_names
    # `dict.fromkeys` keeps the first occurrence of every name, in insertion order.
    ordered_unique = dict.fromkeys(text_names + image_names)
    return [*ordered_unique]

View File

@@ -127,6 +127,11 @@ class VisionTextDualEncoderProcessor(ProcessorMixin):
return self.tokenizer.decode(*args, **kwargs) return self.tokenizer.decode(*args, **kwargs)
@property @property
def model_input_names(self):
    """
    Names of the inputs this processor exposes: the tokenizer's `model_input_names` followed by the image
    processor's `model_input_names`, with repeated names dropped (the first occurrence is kept).
    """
    combined_names = self.tokenizer.model_input_names + self.image_processor.model_input_names
    # Dict keys are unique and keep insertion order, so this de-duplicates without reordering.
    return list({name: None for name in combined_names})
def feature_extractor_class(self): def feature_extractor_class(self):
warnings.warn( warnings.warn(
"`feature_extractor_class` is deprecated and will be removed in v4.27. Use `image_processor_class`" "`feature_extractor_class` is deprecated and will be removed in v4.27. Use `image_processor_class`"

View File

@@ -107,3 +107,7 @@ class XCLIPProcessor(ProcessorMixin):
the docstring of this method for more information. the docstring of this method for more information.
""" """
return self.tokenizer.decode(*args, **kwargs) return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
    """Return the fixed list of input names for this processor (a new list on every access)."""
    input_names = ["input_ids", "attention_mask", "position_ids", "pixel_values"]
    return input_names

View File

@@ -227,6 +227,11 @@ class ProcessorMixin(PushToHubMixin):
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs)) args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
return args return args
@property
def model_input_names(self):
    """
    Return the `model_input_names` of the first attribute listed in `self.attributes`, or `None` when that
    attribute does not define `model_input_names`.
    """
    leading_attribute_name = self.attributes[0]
    return getattr(getattr(self, leading_attribute_name), "model_input_names", None)
ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub) ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format( ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(

View File

@@ -187,3 +187,16 @@ class CLIPProcessorTest(unittest.TestCase):
decoded_tok = tokenizer.batch_decode(predicted_ids) decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor) self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
    """Keys returned by a text + image call must equal `processor.model_input_names`, order included."""
    processor = CLIPProcessor(
        feature_extractor=self.get_feature_extractor(),
        tokenizer=self.get_tokenizer(),
    )

    encoding = processor(text="lower newer", images=self.prepare_image_inputs())

    self.assertListEqual(list(encoding.keys()), processor.model_input_names)

View File

@@ -231,3 +231,16 @@ class FlavaProcessorTest(unittest.TestCase):
decoded_tok = tokenizer.batch_decode(predicted_ids) decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor) self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
    """Keys returned by a text + image call must equal `processor.model_input_names`, order included."""
    processor = FlavaProcessor(
        feature_extractor=self.get_feature_extractor(),
        tokenizer=self.get_tokenizer(),
    )

    encoding = processor(text="lower newer", images=self.prepare_image_inputs())

    self.assertListEqual(list(encoding.keys()), processor.model_input_names)

View File

@@ -19,6 +19,8 @@ import tempfile
import unittest import unittest
from typing import List from typing import List
import numpy as np
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
from transformers.models.layoutlmv2 import LayoutLMv2Tokenizer, LayoutLMv2TokenizerFast from transformers.models.layoutlmv2 import LayoutLMv2Tokenizer, LayoutLMv2TokenizerFast
from transformers.models.layoutlmv2.tokenization_layoutlmv2 import VOCAB_FILES_NAMES from transformers.models.layoutlmv2.tokenization_layoutlmv2 import VOCAB_FILES_NAMES
@@ -86,6 +88,17 @@ class LayoutLMv2ProcessorTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
shutil.rmtree(self.tmpdirname) shutil.rmtree(self.tmpdirname)
def prepare_image_inputs(self):
    """Return a one-element list holding a PIL image built from a random `uint8` array of shape (3, 30, 400)."""
    channels_first = np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)
    # Move the channel axis to the end before handing the array to `Image.fromarray`.
    return [Image.fromarray(np.moveaxis(channels_first, 0, -1))]
def test_save_load_pretrained_default(self): def test_save_load_pretrained_default(self):
feature_extractor = self.get_feature_extractor() feature_extractor = self.get_feature_extractor()
tokenizers = self.get_tokenizers() tokenizers = self.get_tokenizers()
@@ -133,6 +146,20 @@ class LayoutLMv2ProcessorTest(unittest.TestCase):
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor) self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)
def test_model_input_names(self):
    """Keys returned by a text + image call must equal `processor.model_input_names`, order included."""
    processor = LayoutLMv2Processor(
        feature_extractor=self.get_feature_extractor(),
        tokenizer=self.get_tokenizer(),
    )

    # NOTE(review): `return_codebook_pixels` / `return_image_mask` are extra arguments whose handling is not
    # visible here - confirm this processor actually needs (or safely ignores) them.
    encoding = processor(
        text="lower newer",
        images=self.prepare_image_inputs(),
        return_codebook_pixels=False,
        return_image_mask=False,
    )

    self.assertListEqual(list(encoding.keys()), processor.model_input_names)
@slow @slow
def test_overflowing_tokens(self): def test_overflowing_tokens(self):
# In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences). # In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences).

View File

@@ -19,6 +19,8 @@ import tempfile
import unittest import unittest
from typing import List from typing import List
import numpy as np
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
from transformers.models.layoutlmv3 import LayoutLMv3Tokenizer, LayoutLMv3TokenizerFast from transformers.models.layoutlmv3 import LayoutLMv3Tokenizer, LayoutLMv3TokenizerFast
from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES from transformers.models.layoutlmv3.tokenization_layoutlmv3 import VOCAB_FILES_NAMES
@@ -99,6 +101,17 @@ class LayoutLMv3ProcessorTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
shutil.rmtree(self.tmpdirname) shutil.rmtree(self.tmpdirname)
def prepare_image_inputs(self):
    """Return a one-element list holding a PIL image built from a random `uint8` array of shape (3, 30, 400)."""
    channels_first = np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)
    # Move the channel axis to the end before handing the array to `Image.fromarray`.
    return [Image.fromarray(np.moveaxis(channels_first, 0, -1))]
def test_save_load_pretrained_default(self): def test_save_load_pretrained_default(self):
feature_extractor = self.get_feature_extractor() feature_extractor = self.get_feature_extractor()
tokenizers = self.get_tokenizers() tokenizers = self.get_tokenizers()
@@ -146,6 +159,20 @@ class LayoutLMv3ProcessorTest(unittest.TestCase):
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor) self.assertIsInstance(processor.feature_extractor, LayoutLMv3FeatureExtractor)
def test_model_input_names(self):
    """Keys returned by a text + image call must equal `processor.model_input_names`, order included."""
    processor = LayoutLMv3Processor(
        feature_extractor=self.get_feature_extractor(),
        tokenizer=self.get_tokenizer(),
    )

    # NOTE(review): `return_codebook_pixels` / `return_image_mask` are extra arguments whose handling is not
    # visible here - confirm this processor actually needs (or safely ignores) them.
    encoding = processor(
        text="lower newer",
        images=self.prepare_image_inputs(),
        return_codebook_pixels=False,
        return_image_mask=False,
    )

    self.assertListEqual(list(encoding.keys()), processor.model_input_names)
# different use cases tests # different use cases tests
@require_torch @require_torch

View File

@@ -19,6 +19,8 @@ import tempfile
import unittest import unittest
from typing import List from typing import List
import numpy as np
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
from transformers.models.layoutxlm import LayoutXLMTokenizer, LayoutXLMTokenizerFast from transformers.models.layoutxlm import LayoutXLMTokenizer, LayoutXLMTokenizerFast
from transformers.testing_utils import ( from transformers.testing_utils import (
@@ -74,6 +76,17 @@ class LayoutXLMProcessorTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
shutil.rmtree(self.tmpdirname) shutil.rmtree(self.tmpdirname)
def prepare_image_inputs(self):
    """Return a one-element list holding a PIL image built from a random `uint8` array of shape (3, 30, 400)."""
    channels_first = np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)
    # Move the channel axis to the end before handing the array to `Image.fromarray`.
    return [Image.fromarray(np.moveaxis(channels_first, 0, -1))]
def test_save_load_pretrained_default(self): def test_save_load_pretrained_default(self):
feature_extractor = self.get_feature_extractor() feature_extractor = self.get_feature_extractor()
tokenizers = self.get_tokenizers() tokenizers = self.get_tokenizers()
@@ -126,6 +139,20 @@ class LayoutXLMProcessorTest(unittest.TestCase):
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor) self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)
def test_model_input_names(self):
    """Keys returned by a text + image call must equal `processor.model_input_names`, order included."""
    processor = LayoutXLMProcessor(
        feature_extractor=self.get_feature_extractor(),
        tokenizer=self.get_tokenizer(),
    )

    # NOTE(review): `return_codebook_pixels` / `return_image_mask` are extra arguments whose handling is not
    # visible here - confirm this processor actually needs (or safely ignores) them.
    encoding = processor(
        text="lower newer",
        images=self.prepare_image_inputs(),
        return_codebook_pixels=False,
        return_image_mask=False,
    )

    self.assertListEqual(list(encoding.keys()), processor.model_input_names)
@slow @slow
def test_overflowing_tokens(self): def test_overflowing_tokens(self):
# In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences). # In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences).

View File

@@ -133,6 +133,18 @@ class MarkupLMProcessorTest(unittest.TestCase):
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string()) self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
self.assertIsInstance(processor.feature_extractor, MarkupLMFeatureExtractor) self.assertIsInstance(processor.feature_extractor, MarkupLMFeatureExtractor)
def test_model_input_names(self):
    """The processor must report the same `model_input_names` as the tokenizer it wraps."""
    feature_extractor = self.get_feature_extractor()
    tokenizer = self.get_tokenizer()
    processor = MarkupLMProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

    processor_names = processor.model_input_names
    tokenizer_names = tokenizer.model_input_names

    self.assertListEqual(
        processor_names,
        tokenizer_names,
        msg="`processor` and `tokenizer` model input names do not match",
    )
# different use cases tests # different use cases tests
@require_bs4 @require_bs4

View File

@@ -144,3 +144,15 @@ class MCTCTProcessorTest(unittest.TestCase):
decoded_tok = tokenizer.batch_decode(predicted_ids) decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor) self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
    """The processor must report the same `model_input_names` as the feature extractor it wraps."""
    feature_extractor = self.get_feature_extractor()
    processor = MCTCTProcessor(feature_extractor=feature_extractor, tokenizer=self.get_tokenizer())

    processor_names = processor.model_input_names
    extractor_names = feature_extractor.model_input_names

    self.assertListEqual(
        processor_names,
        extractor_names,
        msg="`processor` and `feature_extractor` model input names do not match",
    )

View File

@@ -239,3 +239,16 @@ class OwlViTProcessorTest(unittest.TestCase):
decoded_tok = tokenizer.batch_decode(predicted_ids) decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor) self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
    """Keys returned by a text + image call must equal `processor.model_input_names`, order included."""
    processor = OwlViTProcessor(
        feature_extractor=self.get_feature_extractor(),
        tokenizer=self.get_tokenizer(),
    )

    encoding = processor(text="lower newer", images=self.prepare_image_inputs())

    self.assertListEqual(list(encoding.keys()), processor.model_input_names)

View File

@@ -144,3 +144,15 @@ class Speech2TextProcessorTest(unittest.TestCase):
decoded_tok = tokenizer.batch_decode(predicted_ids) decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor) self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
    """The processor must report the same `model_input_names` as the feature extractor it wraps."""
    feature_extractor = self.get_feature_extractor()
    processor = Speech2TextProcessor(feature_extractor=feature_extractor, tokenizer=self.get_tokenizer())

    processor_names = processor.model_input_names
    extractor_names = feature_extractor.model_input_names

    self.assertListEqual(
        processor_names,
        extractor_names,
        msg="`processor` and `feature_extractor` model input names do not match",
    )

View File

@@ -168,3 +168,16 @@ class VisionTextDualEncoderProcessorTest(unittest.TestCase):
decoded_tok = tokenizer.batch_decode(predicted_ids) decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor) self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
    """Keys returned by a text + image call must equal `processor.model_input_names`, order included."""
    image_processor = self.get_image_processor()
    tokenizer = self.get_tokenizer()
    # NOTE(review): the image processor is passed through the `feature_extractor` keyword - confirm whether
    # the constructor prefers an `image_processor` keyword instead.
    processor = VisionTextDualEncoderProcessor(tokenizer=tokenizer, feature_extractor=image_processor)

    encoding = processor(text="lower newer", images=self.prepare_image_inputs())

    self.assertListEqual(list(encoding.keys()), processor.model_input_names)

View File

@@ -137,3 +137,15 @@ class Wav2Vec2ProcessorTest(unittest.TestCase):
decoded_tok = tokenizer.batch_decode(predicted_ids) decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor) self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
    """The processor must report the same `model_input_names` as the feature extractor it wraps."""
    feature_extractor = self.get_feature_extractor()
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=self.get_tokenizer())

    processor_names = processor.model_input_names
    extractor_names = feature_extractor.model_input_names

    self.assertListEqual(
        processor_names,
        extractor_names,
        msg="`processor` and `feature_extractor` model input names do not match",
    )

View File

@@ -367,6 +367,19 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text) self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text)
def test_model_input_names(self):
    """The processor must report the same `model_input_names` as the feature extractor it wraps."""
    feature_extractor = self.get_feature_extractor()
    processor = Wav2Vec2ProcessorWithLM(
        feature_extractor=feature_extractor,
        tokenizer=self.get_tokenizer(),
        decoder=self.get_decoder(),
    )

    processor_names = processor.model_input_names
    extractor_names = feature_extractor.model_input_names

    self.assertListEqual(
        processor_names,
        extractor_names,
        msg="`processor` and `feature_extractor` model input names do not match",
    )
@staticmethod @staticmethod
def get_from_offsets(offsets, key): def get_from_offsets(offsets, key):
retrieved_list = [d[key] for d in offsets] retrieved_list = [d[key] for d in offsets]

View File

@@ -116,3 +116,15 @@ class WhisperProcessorTest(unittest.TestCase):
decoded_tok = tokenizer.batch_decode(predicted_ids) decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor) self.assertListEqual(decoded_tok, decoded_processor)
def test_model_input_names(self):
    """The processor must report the same `model_input_names` as the feature extractor it wraps."""
    feature_extractor = self.get_feature_extractor()
    processor = WhisperProcessor(feature_extractor=feature_extractor, tokenizer=self.get_tokenizer())

    processor_names = processor.model_input_names
    extractor_names = feature_extractor.model_input_names

    self.assertListEqual(
        processor_names,
        extractor_names,
        msg="`processor` and `feature_extractor` model input names do not match",
    )