Update kwargs validation for preprocess with decorator (#32024)
* BLIP preprocess * BIT preprocess * BRIDGETOWER preprocess * CHAMELEON preprocess * CHINESE_CLIP preprocess * CONVNEXT preprocess * DEIT preprocess * DONUT preprocess * DPT preprocess * FLAVA preprocess * EFFICIENTNET preprocess * FUYU preprocess * GLPN preprocess * IMAGEGPT preprocess * INTRUCTBLIPVIDEO preprocess * VIVIT preprocess * ZOEDEPTH preprocess * VITMATTE preprocess * VIT preprocess * VILT preprocess * VIDEOMAE preprocess * VIDEOLLAVA * TVP processing * TVP fixup * SWIN2SR preprocess * SIGLIP preprocess * SAM preprocess * RT-DETR preprocess * PVT preprocess * POOLFORMER preprocess * PERCEIVER preprocess * OWLVIT preprocess * OWLV2 preprocess * NOUGAT preprocess * MOBILEVIT preprocess * MOBILENETV2 preprocess * MOBILENETV1 preprocess * LEVIT preprocess * LAYOUTLMV2 preprocess * LAYOUTLMV3 preprocess * Add test * Update tests
This commit is contained in:
committed by
GitHub
parent
e85d86398a
commit
fb66ef8147
@@ -15,6 +15,7 @@
|
||||
|
||||
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -51,6 +52,7 @@ class VitMatteImageProcessingTester(unittest.TestCase):
|
||||
image_mean=[0.5, 0.5, 0.5],
|
||||
image_std=[0.5, 0.5, 0.5],
|
||||
):
|
||||
super().__init__()
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
@@ -197,3 +199,20 @@ class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image = np.random.randn(3, 249, 512)
|
||||
images = image_processing.pad_image(image)
|
||||
assert images.shape == (3, 256, 512)
|
||||
|
||||
def test_image_processor_preprocess_arguments(self):
|
||||
# vitmatte require additional trimap input for image_processor
|
||||
# that is why we override original common test
|
||||
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class(**self.image_processor_dict)
|
||||
image = self.image_processor_tester.prepare_image_inputs()[0]
|
||||
trimap = np.random.randint(0, 3, size=image.size[::-1])
|
||||
|
||||
with warnings.catch_warnings(record=True) as raised_warnings:
|
||||
warnings.simplefilter("always")
|
||||
image_processor(image, trimaps=trimap, extra_argument=True)
|
||||
|
||||
messages = " ".join([str(w.message) for w in raised_warnings])
|
||||
self.assertGreaterEqual(len(raised_warnings), 1)
|
||||
self.assertIn("extra_argument", messages)
|
||||
|
||||
Reference in New Issue
Block a user