Update kwargs validation for preprocess with decorator (#32024)

* BLIP preprocess

* BIT preprocess

* BRIDGETOWER preprocess

* CHAMELEON preprocess

* CHINESE_CLIP preprocess

* CONVNEXT preprocess

* DEIT preprocess

* DONUT preprocess

* DPT preprocess

* FLAVA preprocess

* EFFICIENTNET preprocess

* FUYU preprocess

* GLPN preprocess

* IMAGEGPT preprocess

* INTRUCTBLIPVIDEO preprocess

* VIVIT preprocess

* ZOEDEPTH preprocess

* VITMATTE preprocess

* VIT preprocess

* VILT preprocess

* VIDEOMAE preprocess

* VIDEOLLAVA

* TVP processing

* TVP fixup

* SWIN2SR preprocess

* SIGLIP preprocess

* SAM preprocess

* RT-DETR preprocess

* PVT preprocess

* POOLFORMER preprocess

* PERCEIVER preprocess

* OWLVIT preprocess

* OWLV2 preprocess

* NOUGAT preprocess

* MOBILEVIT preprocess

* MOBILENETV2 preprocess

* MOBILENETV1 preprocess

* LEVIT preprocess

* LAYOUTLMV2 preprocess

* LAYOUTLMV3 preprocess

* Add test

* Update tests
This commit is contained in:
Pavel Iakubovskii
2024-08-06 11:33:05 +01:00
committed by GitHub
parent e85d86398a
commit fb66ef8147
76 changed files with 189 additions and 826 deletions

View File

@@ -15,6 +15,7 @@
import unittest
import warnings
import numpy as np
@@ -51,6 +52,7 @@ class VitMatteImageProcessingTester(unittest.TestCase):
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
super().__init__()
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
@@ -197,3 +199,20 @@ class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image = np.random.randn(3, 249, 512)
images = image_processing.pad_image(image)
assert images.shape == (3, 256, 512)
def test_image_processor_preprocess_arguments(self):
# vitmatte require additional trimap input for image_processor
# that is why we override original common test
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class(**self.image_processor_dict)
image = self.image_processor_tester.prepare_image_inputs()[0]
trimap = np.random.randint(0, 3, size=image.size[::-1])
with warnings.catch_warnings(record=True) as raised_warnings:
warnings.simplefilter("always")
image_processor(image, trimaps=trimap, extra_argument=True)
messages = " ".join([str(w.message) for w in raised_warnings])
self.assertGreaterEqual(len(raised_warnings), 1)
self.assertIn("extra_argument", messages)