[processor] Add 'model input names' property (#20117)
* [processor] Add 'model input names' property * add test * no f string * add generic property method to mixin * copy to multimodal * copy to vision * tests for all audio * remove ad-hoc tests * style * fix flava test * fix test * fix processor code
This commit is contained in:
@@ -19,6 +19,8 @@ import tempfile
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
|
||||
from transformers.models.layoutxlm import LayoutXLMTokenizer, LayoutXLMTokenizerFast
|
||||
from transformers.testing_utils import (
|
||||
@@ -74,6 +76,17 @@ class LayoutXLMProcessorTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def prepare_image_inputs(self):
|
||||
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
||||
or a list of PyTorch tensors if one specifies torchify=True.
|
||||
"""
|
||||
|
||||
image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
|
||||
|
||||
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
|
||||
|
||||
return image_inputs
|
||||
|
||||
def test_save_load_pretrained_default(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizers = self.get_tokenizers()
|
||||
@@ -126,6 +139,20 @@ class LayoutXLMProcessorTest(unittest.TestCase):
|
||||
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
|
||||
self.assertIsInstance(processor.feature_extractor, LayoutLMv2FeatureExtractor)
|
||||
|
||||
def test_model_input_names(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = LayoutXLMProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
# add extra args
|
||||
inputs = processor(text=input_str, images=image_input, return_codebook_pixels=False, return_image_mask=False)
|
||||
|
||||
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||
|
||||
@slow
|
||||
def test_overflowing_tokens(self):
|
||||
# In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences).
|
||||
|
||||
Reference in New Issue
Block a user