Update image processor parameters if creating with kwargs (#20866)

* Update parameters if creating with kwargs

* Shallow copy to prevent mutating input

* Pass all args in constructor dict - warnings in init

* Fix typo
This commit is contained in:
amyeroberts
2023-01-04 14:29:48 +00:00
committed by GitHub
parent f9e977be70
commit 292acd71d6
36 changed files with 378 additions and 5 deletions

View File

@@ -316,8 +316,17 @@ class ImageProcessingMixin(PushToHubMixin):
[`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those [`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those
parameters. parameters.
""" """
image_processor_dict = image_processor_dict.copy()
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
# The `size` parameter is a dict and was previously an int or tuple in feature extractors.
# We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate
# dict within the image processor and isn't overwritten if `size` is passed in as a kwarg.
if "size" in kwargs and "size" in image_processor_dict:
image_processor_dict["size"] = kwargs.pop("size")
if "crop_size" in kwargs and "crop_size" in image_processor_dict:
image_processor_dict["crop_size"] = kwargs.pop("crop_size")
image_processor = cls(**image_processor_dict) image_processor = cls(**image_processor_dict)
# Update image_processor with kwargs if needed # Update image_processor with kwargs if needed

View File

@@ -15,7 +15,7 @@
"""Image processor class for Beit.""" """Image processor class for Beit."""
import warnings import warnings
from typing import Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
@@ -131,6 +131,17 @@ class BeitImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_reduce_labels = do_reduce_labels self.do_reduce_labels = do_reduce_labels
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure `reduce_labels` is updated if image processor
is created using from_dict and kwargs e.g. `BeitImageProcessor.from_pretrained(checkpoint, reduce_labels=True)`
"""
image_processor_dict = image_processor_dict.copy()
if "reduce_labels" in kwargs:
image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
return super().from_dict(image_processor_dict, **kwargs)
def resize( def resize(
self, self,
image: np.ndarray, image: np.ndarray,

View File

@@ -815,6 +815,21 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad self.do_pad = do_pad
@classmethod
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->ConditionalDetr
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `ConditionalDetrImageProcessor.from_pretrained(checkpoint, size=600,
max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->ConditionalDetr # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->ConditionalDetr
def prepare_annotation( def prepare_annotation(
self, self,

View File

@@ -813,6 +813,21 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad self.do_pad = do_pad
@classmethod
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->DeformableDetr
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `DeformableDetrImageProcessor.from_pretrained(checkpoint, size=600,
max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->DeformableDetr # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->DeformableDetr
def prepare_annotation( def prepare_annotation(
self, self,

View File

@@ -797,6 +797,20 @@ class DetrImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad self.do_pad = do_pad
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `DetrImageProcessor.from_pretrained(checkpoint, size=600,
max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
def prepare_annotation( def prepare_annotation(
self, self,
image: np.ndarray, image: np.ndarray,

View File

@@ -17,7 +17,7 @@
import math import math
import random import random
from functools import lru_cache from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
@@ -293,6 +293,19 @@ class FlavaImageProcessor(BaseImageProcessor):
self.codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else FLAVA_CODEBOOK_MEAN self.codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else FLAVA_CODEBOOK_MEAN
self.codebook_image_std = codebook_image_std if codebook_image_std is not None else FLAVA_CODEBOOK_STD self.codebook_image_std = codebook_image_std if codebook_image_std is not None else FLAVA_CODEBOOK_STD
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `FlavaImageProcessor.from_pretrained(checkpoint, codebook_size=600)`
"""
image_processor_dict = image_processor_dict.copy()
if "codebook_size" in kwargs:
image_processor_dict["codebook_size"] = kwargs.pop("codebook_size")
if "codebook_crop_size" in kwargs:
image_processor_dict["codebook_crop_size"] = kwargs.pop("codebook_crop_size")
return super().from_dict(image_processor_dict, **kwargs)
@lru_cache() @lru_cache()
def masking_generator( def masking_generator(
self, self,

View File

@@ -400,7 +400,7 @@ class MaskFormerImageProcessor(BaseImageProcessor):
if "size_divisibility" in kwargs: if "size_divisibility" in kwargs:
warnings.warn( warnings.warn(
"The `size_divisibility` argument is deprecated and will be removed in v4.27. Please use " "The `size_divisibility` argument is deprecated and will be removed in v4.27. Please use "
"`size_divisibility` instead.", "`size_divisor` instead.",
FutureWarning, FutureWarning,
) )
size_divisor = kwargs.pop("size_divisibility") size_divisor = kwargs.pop("size_divisibility")
@@ -432,6 +432,19 @@ class MaskFormerImageProcessor(BaseImageProcessor):
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.reduce_labels = reduce_labels self.reduce_labels = reduce_labels
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `MaskFormerImageProcessor.from_pretrained(checkpoint, max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "size_divisibility" in kwargs:
image_processor_dict["size_divisibility"] = kwargs.pop("size_divisibility")
return super().from_dict(image_processor_dict, **kwargs)
@property @property
def size_divisibility(self): def size_divisibility(self):
warnings.warn( warnings.warn(

View File

@@ -15,7 +15,7 @@
"""Image processor class for Segformer.""" """Image processor class for Segformer."""
import warnings import warnings
from typing import Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
@@ -119,6 +119,18 @@ class SegformerImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_reduce_labels = do_reduce_labels self.do_reduce_labels = do_reduce_labels
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure `reduce_labels` is updated if image processor
is created using from_dict and kwargs e.g. `SegformerImageProcessor.from_pretrained(checkpoint,
reduce_labels=True)`
"""
image_processor_dict = image_processor_dict.copy()
if "reduce_labels" in kwargs:
image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels")
return super().from_dict(image_processor_dict, **kwargs)
def resize( def resize(
self, self,
image: np.ndarray, image: np.ndarray,

View File

@@ -185,6 +185,18 @@ class ViltImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_pad = do_pad self.do_pad = do_pad
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure `reduce_labels` is updated if image processor
is created using from_dict and kwargs e.g. `ViltImageProcessor.from_pretrained(checkpoint,
pad_and_return_pixel_mask=False)`
"""
image_processor_dict = image_processor_dict.copy()
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
def resize( def resize(
self, self,
image: np.ndarray, image: np.ndarray,

View File

@@ -725,6 +725,21 @@ class YolosImageProcessor(BaseImageProcessor):
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad self.do_pad = do_pad
@classmethod
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->Yolos
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
created using from_dict and kwargs e.g. `YolosImageProcessor.from_pretrained(checkpoint, size=600,
max_size=800)`
"""
image_processor_dict = image_processor_dict.copy()
if "max_size" in kwargs:
image_processor_dict["max_size"] = kwargs.pop("max_size")
if "pad_and_return_pixel_mask" in kwargs:
image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
return super().from_dict(image_processor_dict, **kwargs)
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation
def prepare_annotation( def prepare_annotation(
self, self,

View File

@@ -125,6 +125,19 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
self.assertTrue(hasattr(feature_extractor, "image_mean")) self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "image_std"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 20, "width": 20})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
self.assertEqual(feature_extractor.do_reduce_labels, False)
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, crop_size=84, reduce_labels=True
)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
self.assertEqual(feature_extractor.do_reduce_labels, True)
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -135,6 +135,15 @@ class ChineseCLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittes
self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb")) self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 224, "width": 224})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -135,6 +135,15 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb")) self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 20})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -133,6 +133,17 @@ class ConditionalDetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, uni
self.assertTrue(hasattr(feature_extractor, "do_resize")) self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 18, "longest_edge": 1333})
self.assertEqual(feature_extractor.do_pad, True)
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, max_size=84, pad_and_return_pixel_mask=False
)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84})
self.assertEqual(feature_extractor.do_pad, False)
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -96,6 +96,13 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
self.assertTrue(hasattr(feature_extractor, "image_mean")) self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "image_std"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 20})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -135,6 +135,17 @@ class DeformableDetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, unit
self.assertTrue(hasattr(feature_extractor, "do_pad")) self.assertTrue(hasattr(feature_extractor, "do_pad"))
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 18, "longest_edge": 1333})
self.assertEqual(feature_extractor.do_pad, True)
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, max_size=84, pad_and_return_pixel_mask=False
)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84})
self.assertEqual(feature_extractor.do_pad, False)
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -103,6 +103,15 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
self.assertTrue(hasattr(feature_extractor, "image_mean")) self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "image_std"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 20, "width": 20})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -136,6 +136,17 @@ class DetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
self.assertTrue(hasattr(feature_extractor, "do_pad")) self.assertTrue(hasattr(feature_extractor, "do_pad"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 18, "longest_edge": 1333})
self.assertEqual(feature_extractor.do_pad, True)
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, max_size=84, pad_and_return_pixel_mask=False
)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84})
self.assertEqual(feature_extractor.do_pad, False)
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -103,6 +103,17 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
self.assertTrue(hasattr(feature_extractor, "image_mean")) self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "image_std"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 18, "width": 20})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
# Previous config had dimensions in (width, height) order
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=(42, 84))
self.assertEqual(feature_extractor.size, {"height": 84, "width": 42})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -92,6 +92,13 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
self.assertTrue(hasattr(feature_extractor, "do_resize")) self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
def test_call_pil(self): def test_call_pil(self):
# Initialize feature_extractor # Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)

View File

@@ -193,6 +193,21 @@ class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
self.assertTrue(hasattr(feature_extractor, "codebook_image_mean")) self.assertTrue(hasattr(feature_extractor, "codebook_image_mean"))
self.assertTrue(hasattr(feature_extractor, "codebook_image_std")) self.assertTrue(hasattr(feature_extractor, "codebook_image_std"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 224, "width": 224})
self.assertEqual(feature_extractor.crop_size, {"height": 224, "width": 224})
self.assertEqual(feature_extractor.codebook_size, {"height": 112, "width": 112})
self.assertEqual(feature_extractor.codebook_crop_size, {"height": 112, "width": 112})
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, crop_size=84, codebook_size=33, codebook_crop_size=66
)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
self.assertEqual(feature_extractor.codebook_size, {"height": 33, "width": 33})
self.assertEqual(feature_extractor.codebook_crop_size, {"height": 66, "width": 66})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -96,6 +96,13 @@ class ImageGPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
self.assertTrue(hasattr(feature_extractor, "do_normalize")) self.assertTrue(hasattr(feature_extractor, "do_normalize"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
def test_feat_extract_to_json_string(self): def test_feat_extract_to_json_string(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict) feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
obj = json.loads(feat_extract.to_json_string()) obj = json.loads(feat_extract.to_json_string())

View File

@@ -80,6 +80,13 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
self.assertTrue(hasattr(feature_extractor, "apply_ocr")) self.assertTrue(hasattr(feature_extractor, "apply_ocr"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -80,6 +80,13 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
self.assertTrue(hasattr(feature_extractor, "apply_ocr")) self.assertTrue(hasattr(feature_extractor, "apply_ocr"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -100,6 +100,15 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
self.assertTrue(hasattr(feature_extractor, "do_center_crop")) self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 18})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -152,6 +152,17 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
self.assertTrue(hasattr(feature_extractor, "ignore_index")) self.assertTrue(hasattr(feature_extractor, "ignore_index"))
self.assertTrue(hasattr(feature_extractor, "num_labels")) self.assertTrue(hasattr(feature_extractor, "num_labels"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 32, "longest_edge": 1333})
self.assertEqual(feature_extractor.size_divisor, 0)
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, max_size=84, size_divisibility=8
)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84})
self.assertEqual(feature_extractor.size_divisor, 8)
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -89,6 +89,15 @@ class MobileNetV1FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittes
self.assertTrue(hasattr(feature_extractor, "do_center_crop")) self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
self.assertTrue(hasattr(feature_extractor, "center_crop")) self.assertTrue(hasattr(feature_extractor, "center_crop"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 20})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -87,7 +87,16 @@ class MobileNetV2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittes
self.assertTrue(hasattr(feature_extractor, "do_resize")) self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
self.assertTrue(hasattr(feature_extractor, "do_center_crop")) self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
self.assertTrue(hasattr(feature_extractor, "center_crop")) self.assertTrue(hasattr(feature_extractor, "crop_size"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 20})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -93,6 +93,15 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
self.assertTrue(hasattr(feature_extractor, "center_crop")) self.assertTrue(hasattr(feature_extractor, "center_crop"))
self.assertTrue(hasattr(feature_extractor, "do_flip_channel_order")) self.assertTrue(hasattr(feature_extractor, "do_flip_channel_order"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 20})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -103,6 +103,15 @@ class OwlViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Tes
self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_convert_rgb")) self.assertTrue(hasattr(feature_extractor, "do_convert_rgb"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 18, "width": 18})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_call_pil(self): def test_call_pil(self):
# Initialize feature_extractor # Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)

View File

@@ -97,6 +97,15 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest
self.assertTrue(hasattr(feature_extractor, "image_mean")) self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "image_std"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 30})
self.assertEqual(feature_extractor.crop_size, {"height": 30, "width": 30})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -115,6 +115,17 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.
self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_reduce_labels")) self.assertTrue(hasattr(feature_extractor, "do_reduce_labels"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 30, "width": 30})
self.assertEqual(feature_extractor.do_reduce_labels, False)
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, reduce_labels=True
)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
self.assertEqual(feature_extractor.do_reduce_labels, True)
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -100,6 +100,15 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T
self.assertTrue(hasattr(feature_extractor, "do_center_crop")) self.assertTrue(hasattr(feature_extractor, "do_center_crop"))
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 18})
self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -136,6 +136,13 @@ class ViltFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
self.assertTrue(hasattr(feature_extractor, "size_divisor")) self.assertTrue(hasattr(feature_extractor, "size_divisor"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 30})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -92,6 +92,13 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa
self.assertTrue(hasattr(feature_extractor, "do_resize")) self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"height": 18, "width": 18})
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42)
self.assertEqual(feature_extractor.size, {"height": 42, "width": 42})
def test_batch_feature(self): def test_batch_feature(self):
pass pass

View File

@@ -133,6 +133,17 @@ class YolosFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test
self.assertTrue(hasattr(feature_extractor, "do_resize")) self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size"))
def test_feat_extract_from_dict_with_kwargs(self):
feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict)
self.assertEqual(feature_extractor.size, {"shortest_edge": 18, "longest_edge": 1333})
self.assertEqual(feature_extractor.do_pad, True)
feature_extractor = self.feature_extraction_class.from_dict(
self.feat_extract_dict, size=42, max_size=84, pad_and_return_pixel_mask=False
)
self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84})
self.assertEqual(feature_extractor.do_pad, False)
def test_batch_feature(self): def test_batch_feature(self):
pass pass