From 292acd71d6d5305bbc5470351a6bc412d678cdae Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Wed, 4 Jan 2023 14:29:48 +0000 Subject: [PATCH] Update image processor parameters if creating with kwargs (#20866) * Update parameters if creating with kwargs * Shallow copy to prevent mutating input * Pass all args in constructor dict - warnings in init * Fix typo --- src/transformers/image_processing_utils.py | 9 +++++++++ .../models/beit/image_processing_beit.py | 13 ++++++++++++- .../image_processing_conditional_detr.py | 15 +++++++++++++++ .../image_processing_deformable_detr.py | 15 +++++++++++++++ .../models/detr/image_processing_detr.py | 14 ++++++++++++++ .../models/flava/image_processing_flava.py | 15 ++++++++++++++- .../maskformer/image_processing_maskformer.py | 15 ++++++++++++++- .../segformer/image_processing_segformer.py | 14 +++++++++++++- .../models/vilt/image_processing_vilt.py | 12 ++++++++++++ .../models/yolos/image_processing_yolos.py | 15 +++++++++++++++ tests/models/beit/test_feature_extraction_beit.py | 13 +++++++++++++ .../test_feature_extraction_chinese_clip.py | 9 +++++++++ tests/models/clip/test_feature_extraction_clip.py | 9 +++++++++ .../test_feature_extraction_conditional_detr.py | 11 +++++++++++ .../convnext/test_feature_extraction_convnext.py | 7 +++++++ .../test_feature_extraction_deformable_detr.py | 11 +++++++++++ tests/models/deit/test_feature_extraction_deit.py | 9 +++++++++ tests/models/detr/test_feature_extraction_detr.py | 11 +++++++++++ .../models/donut/test_feature_extraction_donut.py | 11 +++++++++++ tests/models/dpt/test_feature_extraction_dpt.py | 7 +++++++ .../models/flava/test_feature_extraction_flava.py | 15 +++++++++++++++ .../imagegpt/test_feature_extraction_imagegpt.py | 7 +++++++ .../test_feature_extraction_layoutlmv2.py | 7 +++++++ .../test_feature_extraction_layoutlmv3.py | 7 +++++++ .../models/levit/test_feature_extraction_levit.py | 9 +++++++++ .../test_feature_extraction_maskformer.py | 11 +++++++++++ .../test_feature_extraction_mobilenet_v1.py | 9 +++++++++ .../test_feature_extraction_mobilenet_v2.py | 11 ++++++++++- .../test_feature_extraction_mobilevit.py | 9 +++++++++ .../owlvit/test_feature_extraction_owlvit.py | 9 +++++++++ .../test_feature_extraction_poolformer.py | 9 +++++++++ .../test_feature_extraction_segformer.py | 11 +++++++++++ .../videomae/test_feature_extraction_videomae.py | 9 +++++++++ tests/models/vilt/test_feature_extraction_vilt.py | 7 +++++++ tests/models/vit/test_feature_extraction_vit.py | 7 +++++++ .../models/yolos/test_feature_extraction_yolos.py | 11 +++++++++++ 36 files changed, 378 insertions(+), 5 deletions(-) diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py index c299128b4d..0be7719782 100644 --- a/src/transformers/image_processing_utils.py +++ b/src/transformers/image_processing_utils.py @@ -316,8 +316,17 @@ class ImageProcessingMixin(PushToHubMixin): [`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those parameters. """ + image_processor_dict = image_processor_dict.copy() return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + # The `size` parameter is a dict and was previously an int or tuple in feature extractors. + # We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate + # dict within the image processor and isn't overwritten if `size` is passed in as a kwarg. + if "size" in kwargs and "size" in image_processor_dict: + image_processor_dict["size"] = kwargs.pop("size") + if "crop_size" in kwargs and "crop_size" in image_processor_dict: + image_processor_dict["crop_size"] = kwargs.pop("crop_size") + image_processor = cls(**image_processor_dict) # Update image_processor with kwargs if needed diff --git a/src/transformers/models/beit/image_processing_beit.py b/src/transformers/models/beit/image_processing_beit.py index 405581fafb..0e81cb9c44 100644 --- a/src/transformers/models/beit/image_processing_beit.py +++ b/src/transformers/models/beit/image_processing_beit.py @@ -15,7 +15,7 @@ """Image processor class for Beit.""" import warnings -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -131,6 +131,17 @@ class BeitImageProcessor(BaseImageProcessor): self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD self.do_reduce_labels = do_reduce_labels + @classmethod + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure `reduce_labels` is updated if image processor + is created using from_dict and kwargs e.g. `BeitImageProcessor.from_pretrained(checkpoint, reduce_labels=True)` + """ + image_processor_dict = image_processor_dict.copy() + if "reduce_labels" in kwargs: + image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels") + return super().from_dict(image_processor_dict, **kwargs) + def resize( self, image: np.ndarray, diff --git a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py index 5e0ea3bd61..b5f4c639f7 100644 --- a/src/transformers/models/conditional_detr/image_processing_conditional_detr.py +++ b/src/transformers/models/conditional_detr/image_processing_conditional_detr.py @@ -815,6 +815,21 @@ class ConditionalDetrImageProcessor(BaseImageProcessor): self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.do_pad = do_pad + @classmethod + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->ConditionalDetr + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is + created using from_dict and kwargs e.g. `ConditionalDetrImageProcessor.from_pretrained(checkpoint, size=600, + max_size=800)` + """ + image_processor_dict = image_processor_dict.copy() + if "max_size" in kwargs: + image_processor_dict["max_size"] = kwargs.pop("max_size") + if "pad_and_return_pixel_mask" in kwargs: + image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask") + return super().from_dict(image_processor_dict, **kwargs) + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->ConditionalDetr def prepare_annotation( self, diff --git a/src/transformers/models/deformable_detr/image_processing_deformable_detr.py b/src/transformers/models/deformable_detr/image_processing_deformable_detr.py index 188c0e139c..499313dd52 100644 --- a/src/transformers/models/deformable_detr/image_processing_deformable_detr.py +++ b/src/transformers/models/deformable_detr/image_processing_deformable_detr.py @@ -813,6 +813,21 @@ class DeformableDetrImageProcessor(BaseImageProcessor): self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.do_pad = do_pad + @classmethod + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->DeformableDetr + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is + created using from_dict and kwargs e.g. `DeformableDetrImageProcessor.from_pretrained(checkpoint, size=600, + max_size=800)` + """ + image_processor_dict = image_processor_dict.copy() + if "max_size" in kwargs: + image_processor_dict["max_size"] = kwargs.pop("max_size") + if "pad_and_return_pixel_mask" in kwargs: + image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask") + return super().from_dict(image_processor_dict, **kwargs) + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation with DETR->DeformableDetr def prepare_annotation( self, diff --git a/src/transformers/models/detr/image_processing_detr.py b/src/transformers/models/detr/image_processing_detr.py index 17e5104a94..957360a96c 100644 --- a/src/transformers/models/detr/image_processing_detr.py +++ b/src/transformers/models/detr/image_processing_detr.py @@ -797,6 +797,20 @@ class DetrImageProcessor(BaseImageProcessor): self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.do_pad = do_pad + @classmethod + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is + created using from_dict and kwargs e.g. `DetrImageProcessor.from_pretrained(checkpoint, size=600, + max_size=800)` + """ + image_processor_dict = image_processor_dict.copy() + if "max_size" in kwargs: + image_processor_dict["max_size"] = kwargs.pop("max_size") + if "pad_and_return_pixel_mask" in kwargs: + image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask") + return super().from_dict(image_processor_dict, **kwargs) + def prepare_annotation( self, image: np.ndarray, diff --git a/src/transformers/models/flava/image_processing_flava.py b/src/transformers/models/flava/image_processing_flava.py index 78bcfa3fa9..22e062306f 100644 --- a/src/transformers/models/flava/image_processing_flava.py +++ b/src/transformers/models/flava/image_processing_flava.py @@ -17,7 +17,7 @@ import math import random from functools import lru_cache -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np @@ -293,6 +293,19 @@ class FlavaImageProcessor(BaseImageProcessor): self.codebook_image_mean = codebook_image_mean if codebook_image_mean is not None else FLAVA_CODEBOOK_MEAN self.codebook_image_std = codebook_image_std if codebook_image_std is not None else FLAVA_CODEBOOK_STD + @classmethod + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is + created using from_dict and kwargs e.g. `FlavaImageProcessor.from_pretrained(checkpoint, codebook_size=600)` + """ + image_processor_dict = image_processor_dict.copy() + if "codebook_size" in kwargs: + image_processor_dict["codebook_size"] = kwargs.pop("codebook_size") + if "codebook_crop_size" in kwargs: + image_processor_dict["codebook_crop_size"] = kwargs.pop("codebook_crop_size") + return super().from_dict(image_processor_dict, **kwargs) + @lru_cache() def masking_generator( self, diff --git a/src/transformers/models/maskformer/image_processing_maskformer.py b/src/transformers/models/maskformer/image_processing_maskformer.py index 50cef60700..1211cd6a82 100644 --- a/src/transformers/models/maskformer/image_processing_maskformer.py +++ b/src/transformers/models/maskformer/image_processing_maskformer.py @@ -400,7 +400,7 @@ class MaskFormerImageProcessor(BaseImageProcessor): if "size_divisibility" in kwargs: warnings.warn( "The `size_divisibility` argument is deprecated and will be removed in v4.27. Please use " - "`size_divisibility` instead.", + "`size_divisor` instead.", FutureWarning, ) size_divisor = kwargs.pop("size_divisibility") @@ -432,6 +432,19 @@ class MaskFormerImageProcessor(BaseImageProcessor): self.ignore_index = ignore_index self.reduce_labels = reduce_labels + @classmethod + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is + created using from_dict and kwargs e.g. `MaskFormerImageProcessor.from_pretrained(checkpoint, max_size=800)` + """ + image_processor_dict = image_processor_dict.copy() + if "max_size" in kwargs: + image_processor_dict["max_size"] = kwargs.pop("max_size") + if "size_divisibility" in kwargs: + image_processor_dict["size_divisibility"] = kwargs.pop("size_divisibility") + return super().from_dict(image_processor_dict, **kwargs) + @property def size_divisibility(self): warnings.warn( diff --git a/src/transformers/models/segformer/image_processing_segformer.py b/src/transformers/models/segformer/image_processing_segformer.py index 4f6498b527..acc6026451 100644 --- a/src/transformers/models/segformer/image_processing_segformer.py +++ b/src/transformers/models/segformer/image_processing_segformer.py @@ -15,7 +15,7 @@ """Image processor class for Segformer.""" import warnings -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -119,6 +119,18 @@ class SegformerImageProcessor(BaseImageProcessor): self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.do_reduce_labels = do_reduce_labels + @classmethod + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure `reduce_labels` is updated if image processor + is created using from_dict and kwargs e.g. `SegformerImageProcessor.from_pretrained(checkpoint, + reduce_labels=True)` + """ + image_processor_dict = image_processor_dict.copy() + if "reduce_labels" in kwargs: + image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels") + return super().from_dict(image_processor_dict, **kwargs) + def resize( self, image: np.ndarray, diff --git a/src/transformers/models/vilt/image_processing_vilt.py b/src/transformers/models/vilt/image_processing_vilt.py index bb86967c5d..e4fbdec032 100644 --- a/src/transformers/models/vilt/image_processing_vilt.py +++ b/src/transformers/models/vilt/image_processing_vilt.py @@ -185,6 +185,18 @@ class ViltImageProcessor(BaseImageProcessor): self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD self.do_pad = do_pad + @classmethod + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure `reduce_labels` is updated if image processor + is created using from_dict and kwargs e.g. `ViltImageProcessor.from_pretrained(checkpoint, + pad_and_return_pixel_mask=False)` + """ + image_processor_dict = image_processor_dict.copy() + if "pad_and_return_pixel_mask" in kwargs: + image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask") + return super().from_dict(image_processor_dict, **kwargs) + def resize( self, image: np.ndarray, diff --git a/src/transformers/models/yolos/image_processing_yolos.py b/src/transformers/models/yolos/image_processing_yolos.py index 3e55412011..ff0cd23caa 100644 --- a/src/transformers/models/yolos/image_processing_yolos.py +++ b/src/transformers/models/yolos/image_processing_yolos.py @@ -725,6 +725,21 @@ class YolosImageProcessor(BaseImageProcessor): self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD self.do_pad = do_pad + @classmethod + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.from_dict with Detr->Yolos + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is + created using from_dict and kwargs e.g. `YolosImageProcessor.from_pretrained(checkpoint, size=600, + max_size=800)` + """ + image_processor_dict = image_processor_dict.copy() + if "max_size" in kwargs: + image_processor_dict["max_size"] = kwargs.pop("max_size") + if "pad_and_return_pixel_mask" in kwargs: + image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask") + return super().from_dict(image_processor_dict, **kwargs) + # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.prepare_annotation def prepare_annotation( self, diff --git a/tests/models/beit/test_feature_extraction_beit.py b/tests/models/beit/test_feature_extraction_beit.py index de9e552393..545b4d79a9 100644 --- a/tests/models/beit/test_feature_extraction_beit.py +++ b/tests/models/beit/test_feature_extraction_beit.py @@ -125,6 +125,19 @@ class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC self.assertTrue(hasattr(feature_extractor, "image_mean")) self.assertTrue(hasattr(feature_extractor, "image_std")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"height": 20, "width": 20}) + self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18}) + self.assertEqual(feature_extractor.do_reduce_labels, False) + + feature_extractor = self.feature_extraction_class.from_dict( + self.feat_extract_dict, size=42, crop_size=84, reduce_labels=True + ) + self.assertEqual(feature_extractor.size, {"height": 42, "width": 42}) + self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84}) + self.assertEqual(feature_extractor.do_reduce_labels, True) + def test_batch_feature(self): pass diff --git a/tests/models/chinese_clip/test_feature_extraction_chinese_clip.py b/tests/models/chinese_clip/test_feature_extraction_chinese_clip.py index 613c904aff..616dfa3ffc 100644 --- a/tests/models/chinese_clip/test_feature_extraction_chinese_clip.py +++ b/tests/models/chinese_clip/test_feature_extraction_chinese_clip.py @@ -135,6 +135,15 @@ class ChineseCLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittes self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "do_convert_rgb")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"height": 224, "width": 224}) + self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84) + self.assertEqual(feature_extractor.size, {"shortest_edge": 42}) + self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84}) + def test_batch_feature(self): pass diff --git a/tests/models/clip/test_feature_extraction_clip.py b/tests/models/clip/test_feature_extraction_clip.py index e9c169cf51..8f29b63bbb 100644 --- a/tests/models/clip/test_feature_extraction_clip.py +++ b/tests/models/clip/test_feature_extraction_clip.py @@ -135,6 +135,15 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "do_convert_rgb")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"shortest_edge": 20}) + self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84) + self.assertEqual(feature_extractor.size, {"shortest_edge": 42}) + self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84}) + def test_batch_feature(self): pass diff --git a/tests/models/conditional_detr/test_feature_extraction_conditional_detr.py b/tests/models/conditional_detr/test_feature_extraction_conditional_detr.py index 92ff4fe08f..4f3a6e21e0 100644 --- a/tests/models/conditional_detr/test_feature_extraction_conditional_detr.py +++ b/tests/models/conditional_detr/test_feature_extraction_conditional_detr.py @@ -133,6 +133,17 @@ class ConditionalDetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, uni self.assertTrue(hasattr(feature_extractor, "do_resize")) self.assertTrue(hasattr(feature_extractor, "size")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"shortest_edge": 18, "longest_edge": 1333}) + self.assertEqual(feature_extractor.do_pad, True) + + feature_extractor = self.feature_extraction_class.from_dict( + self.feat_extract_dict, size=42, max_size=84, pad_and_return_pixel_mask=False + ) + self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84}) + self.assertEqual(feature_extractor.do_pad, False) + def test_batch_feature(self): pass diff --git a/tests/models/convnext/test_feature_extraction_convnext.py b/tests/models/convnext/test_feature_extraction_convnext.py index 1419280f97..9777c3df6d 100644 --- a/tests/models/convnext/test_feature_extraction_convnext.py +++ b/tests/models/convnext/test_feature_extraction_convnext.py @@ -96,6 +96,13 @@ class ConvNextFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T self.assertTrue(hasattr(feature_extractor, "image_mean")) self.assertTrue(hasattr(feature_extractor, "image_std")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"shortest_edge": 20}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42) + self.assertEqual(feature_extractor.size, {"shortest_edge": 42}) + def test_batch_feature(self): pass diff --git a/tests/models/deformable_detr/test_feature_extraction_deformable_detr.py b/tests/models/deformable_detr/test_feature_extraction_deformable_detr.py index e205f47a8d..aaafb7ff2f 100644 --- a/tests/models/deformable_detr/test_feature_extraction_deformable_detr.py +++ b/tests/models/deformable_detr/test_feature_extraction_deformable_detr.py @@ -135,6 +135,17 @@ class DeformableDetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, unit self.assertTrue(hasattr(feature_extractor, "do_pad")) self.assertTrue(hasattr(feature_extractor, "size")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"shortest_edge": 18, "longest_edge": 1333}) + self.assertEqual(feature_extractor.do_pad, True) + + feature_extractor = self.feature_extraction_class.from_dict( + self.feat_extract_dict, size=42, max_size=84, pad_and_return_pixel_mask=False + ) + self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84}) + self.assertEqual(feature_extractor.do_pad, False) + def test_batch_feature(self): pass diff --git a/tests/models/deit/test_feature_extraction_deit.py b/tests/models/deit/test_feature_extraction_deit.py index 32b107756e..f684008ccc 100644 --- a/tests/models/deit/test_feature_extraction_deit.py +++ b/tests/models/deit/test_feature_extraction_deit.py @@ -103,6 +103,15 @@ class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC self.assertTrue(hasattr(feature_extractor, "image_mean")) self.assertTrue(hasattr(feature_extractor, "image_std")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"height": 20, "width": 20}) + self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84) + self.assertEqual(feature_extractor.size, {"height": 42, "width": 42}) + self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84}) + def test_batch_feature(self): pass diff --git a/tests/models/detr/test_feature_extraction_detr.py b/tests/models/detr/test_feature_extraction_detr.py index 12e74c1203..6aafd62da4 100644 --- a/tests/models/detr/test_feature_extraction_detr.py +++ b/tests/models/detr/test_feature_extraction_detr.py @@ -136,6 +136,17 @@ class DetrFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "do_pad")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"shortest_edge": 18, "longest_edge": 1333}) + self.assertEqual(feature_extractor.do_pad, True) + + feature_extractor = self.feature_extraction_class.from_dict( + self.feat_extract_dict, size=42, max_size=84, pad_and_return_pixel_mask=False + ) + self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84}) + self.assertEqual(feature_extractor.do_pad, False) + def test_batch_feature(self): pass diff --git a/tests/models/donut/test_feature_extraction_donut.py b/tests/models/donut/test_feature_extraction_donut.py index e97c32f6c9..4d0f88ac98 100644 --- a/tests/models/donut/test_feature_extraction_donut.py +++ b/tests/models/donut/test_feature_extraction_donut.py @@ -103,6 +103,17 @@ class DonutFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test self.assertTrue(hasattr(feature_extractor, "image_mean")) self.assertTrue(hasattr(feature_extractor, "image_std")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"height": 18, "width": 20}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42) + self.assertEqual(feature_extractor.size, {"height": 42, "width": 42}) + + # Previous config had dimensions in (width, height) order + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=(42, 84)) + self.assertEqual(feature_extractor.size, {"height": 84, "width": 42}) + def test_batch_feature(self): pass diff --git a/tests/models/dpt/test_feature_extraction_dpt.py b/tests/models/dpt/test_feature_extraction_dpt.py index bcfec4b2aa..594b1451a7 100644 --- a/tests/models/dpt/test_feature_extraction_dpt.py +++ b/tests/models/dpt/test_feature_extraction_dpt.py @@ -92,6 +92,13 @@ class DPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa self.assertTrue(hasattr(feature_extractor, "do_resize")) self.assertTrue(hasattr(feature_extractor, "size")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"height": 18, "width": 18}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42) + self.assertEqual(feature_extractor.size, {"height": 42, "width": 42}) + def test_call_pil(self): # Initialize feature_extractor feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) diff --git a/tests/models/flava/test_feature_extraction_flava.py b/tests/models/flava/test_feature_extraction_flava.py index bb771de36f..ba6379e6b3 100644 --- a/tests/models/flava/test_feature_extraction_flava.py +++ b/tests/models/flava/test_feature_extraction_flava.py @@ -193,6 +193,21 @@ class FlavaFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test self.assertTrue(hasattr(feature_extractor, "codebook_image_mean")) self.assertTrue(hasattr(feature_extractor, "codebook_image_std")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"height": 224, "width": 224}) + self.assertEqual(feature_extractor.crop_size, {"height": 224, "width": 224}) + self.assertEqual(feature_extractor.codebook_size, {"height": 112, "width": 112}) + self.assertEqual(feature_extractor.codebook_crop_size, {"height": 112, "width": 112}) + + feature_extractor = self.feature_extraction_class.from_dict( + self.feat_extract_dict, size=42, crop_size=84, codebook_size=33, codebook_crop_size=66 + ) + self.assertEqual(feature_extractor.size, {"height": 42, "width": 42}) + self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84}) + self.assertEqual(feature_extractor.codebook_size, {"height": 33, "width": 33}) + self.assertEqual(feature_extractor.codebook_crop_size, {"height": 66, "width": 66}) + def test_batch_feature(self): pass diff --git a/tests/models/imagegpt/test_feature_extraction_imagegpt.py b/tests/models/imagegpt/test_feature_extraction_imagegpt.py index 0dd614840b..465a6015a3 100644 --- a/tests/models/imagegpt/test_feature_extraction_imagegpt.py +++ b/tests/models/imagegpt/test_feature_extraction_imagegpt.py @@ -96,6 +96,13 @@ class ImageGPTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "do_normalize")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"height": 18, "width": 18}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42) + self.assertEqual(feature_extractor.size, {"height": 42, "width": 42}) + def test_feat_extract_to_json_string(self): feat_extract = self.feature_extraction_class(**self.feat_extract_dict) obj = json.loads(feat_extract.to_json_string()) diff --git a/tests/models/layoutlmv2/test_feature_extraction_layoutlmv2.py b/tests/models/layoutlmv2/test_feature_extraction_layoutlmv2.py index 0a3528e16c..c26eaac16e 100644 --- a/tests/models/layoutlmv2/test_feature_extraction_layoutlmv2.py +++ b/tests/models/layoutlmv2/test_feature_extraction_layoutlmv2.py @@ -80,6 +80,13 @@ class LayoutLMv2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "apply_ocr")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"height": 18, "width": 18}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42) + self.assertEqual(feature_extractor.size, {"height": 42, "width": 42}) + def test_batch_feature(self): pass diff --git a/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py b/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py index 68a32e6e8f..c8eb976bf5 100644 --- a/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py +++ b/tests/models/layoutlmv3/test_feature_extraction_layoutlmv3.py @@ -80,6 +80,13 @@ class LayoutLMv3FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "apply_ocr")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"height": 18, "width": 18}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42) + self.assertEqual(feature_extractor.size, {"height": 42, "width": 42}) + def test_batch_feature(self): pass diff --git a/tests/models/levit/test_feature_extraction_levit.py b/tests/models/levit/test_feature_extraction_levit.py index 138542d85d..2b1472d9b6 100644 --- a/tests/models/levit/test_feature_extraction_levit.py +++ b/tests/models/levit/test_feature_extraction_levit.py @@ -100,6 +100,15 @@ class LevitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test self.assertTrue(hasattr(feature_extractor, "do_center_crop")) self.assertTrue(hasattr(feature_extractor, "size")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"shortest_edge": 18}) + self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84) + self.assertEqual(feature_extractor.size, {"shortest_edge": 42}) + self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84}) + def test_batch_feature(self): pass diff --git a/tests/models/maskformer/test_feature_extraction_maskformer.py b/tests/models/maskformer/test_feature_extraction_maskformer.py index ca2f504c06..624a48ac8e 100644 --- a/tests/models/maskformer/test_feature_extraction_maskformer.py +++ b/tests/models/maskformer/test_feature_extraction_maskformer.py @@ -152,6 +152,17 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest self.assertTrue(hasattr(feature_extractor, "ignore_index")) self.assertTrue(hasattr(feature_extractor, "num_labels")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"shortest_edge": 32, "longest_edge": 1333}) + self.assertEqual(feature_extractor.size_divisor, 0) + + feature_extractor = self.feature_extraction_class.from_dict( + self.feat_extract_dict, size=42, max_size=84, size_divisibility=8 + ) + self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84}) + self.assertEqual(feature_extractor.size_divisor, 8) + def test_batch_feature(self): pass diff --git a/tests/models/mobilenet_v1/test_feature_extraction_mobilenet_v1.py b/tests/models/mobilenet_v1/test_feature_extraction_mobilenet_v1.py index 6ddbd4c126..270d38d5b8 100644 --- a/tests/models/mobilenet_v1/test_feature_extraction_mobilenet_v1.py +++ b/tests/models/mobilenet_v1/test_feature_extraction_mobilenet_v1.py @@ -89,6 +89,15 @@ class MobileNetV1FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittes self.assertTrue(hasattr(feature_extractor, "do_center_crop")) self.assertTrue(hasattr(feature_extractor, "center_crop")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"shortest_edge": 20}) + self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84) + self.assertEqual(feature_extractor.size, {"shortest_edge": 42}) + self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84}) + def test_batch_feature(self): pass diff --git a/tests/models/mobilenet_v2/test_feature_extraction_mobilenet_v2.py b/tests/models/mobilenet_v2/test_feature_extraction_mobilenet_v2.py index 1c88492e4c..3cb4eea218 100644 --- a/tests/models/mobilenet_v2/test_feature_extraction_mobilenet_v2.py +++ b/tests/models/mobilenet_v2/test_feature_extraction_mobilenet_v2.py @@ -87,7 +87,16 @@ class MobileNetV2FeatureExtractionTest(FeatureExtractionSavingTestMixin, unittes self.assertTrue(hasattr(feature_extractor, "do_resize")) self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "do_center_crop")) - self.assertTrue(hasattr(feature_extractor, "center_crop")) + self.assertTrue(hasattr(feature_extractor, "crop_size")) + + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"shortest_edge": 20}) + self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84) + self.assertEqual(feature_extractor.size, {"shortest_edge": 42}) + self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84}) def test_batch_feature(self): pass diff --git a/tests/models/mobilevit/test_feature_extraction_mobilevit.py b/tests/models/mobilevit/test_feature_extraction_mobilevit.py index 1a2f52d0da..468c4689e4 100644 --- a/tests/models/mobilevit/test_feature_extraction_mobilevit.py +++ b/tests/models/mobilevit/test_feature_extraction_mobilevit.py @@ -93,6 +93,15 @@ class MobileViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest. self.assertTrue(hasattr(feature_extractor, "center_crop")) self.assertTrue(hasattr(feature_extractor, "do_flip_channel_order")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"shortest_edge": 20}) + self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84) + self.assertEqual(feature_extractor.size, {"shortest_edge": 42}) + self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84}) + def test_batch_feature(self): pass diff --git a/tests/models/owlvit/test_feature_extraction_owlvit.py b/tests/models/owlvit/test_feature_extraction_owlvit.py index 0435ea91b9..fe259b1169 100644 --- a/tests/models/owlvit/test_feature_extraction_owlvit.py +++ b/tests/models/owlvit/test_feature_extraction_owlvit.py @@ -103,6 +103,15 @@ class OwlViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Tes self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "do_convert_rgb")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"height": 18, "width": 18}) + self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84) + self.assertEqual(feature_extractor.size, {"height": 42, "width": 42}) + self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84}) + def test_call_pil(self): # Initialize feature_extractor feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) diff --git a/tests/models/poolformer/test_feature_extraction_poolformer.py b/tests/models/poolformer/test_feature_extraction_poolformer.py index 41599989b1..b1fffe8a5a 100644 --- a/tests/models/poolformer/test_feature_extraction_poolformer.py +++ b/tests/models/poolformer/test_feature_extraction_poolformer.py @@ -97,6 +97,15 @@ class PoolFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest self.assertTrue(hasattr(feature_extractor, "image_mean")) self.assertTrue(hasattr(feature_extractor, "image_std")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"shortest_edge": 30}) + self.assertEqual(feature_extractor.crop_size, {"height": 30, "width": 30}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84) + self.assertEqual(feature_extractor.size, {"shortest_edge": 42}) + self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84}) + def test_batch_feature(self): pass diff --git a/tests/models/segformer/test_feature_extraction_segformer.py b/tests/models/segformer/test_feature_extraction_segformer.py index b3ba44862b..4257b27b81 100644 --- a/tests/models/segformer/test_feature_extraction_segformer.py +++ b/tests/models/segformer/test_feature_extraction_segformer.py @@ -115,6 +115,17 @@ class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest. self.assertTrue(hasattr(feature_extractor, "image_std")) self.assertTrue(hasattr(feature_extractor, "do_reduce_labels")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"height": 30, "width": 30}) + self.assertEqual(feature_extractor.do_reduce_labels, False) + + feature_extractor = self.feature_extraction_class.from_dict( + self.feat_extract_dict, size=42, reduce_labels=True + ) + self.assertEqual(feature_extractor.size, {"height": 42, "width": 42}) + self.assertEqual(feature_extractor.do_reduce_labels, True) + def test_batch_feature(self): pass diff --git a/tests/models/videomae/test_feature_extraction_videomae.py b/tests/models/videomae/test_feature_extraction_videomae.py index eebdbb7cc3..f792a9be84 100644 --- a/tests/models/videomae/test_feature_extraction_videomae.py +++ b/tests/models/videomae/test_feature_extraction_videomae.py @@ -100,6 +100,15 @@ class VideoMAEFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.T self.assertTrue(hasattr(feature_extractor, "do_center_crop")) self.assertTrue(hasattr(feature_extractor, "size")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"shortest_edge": 18}) + self.assertEqual(feature_extractor.crop_size, {"height": 18, "width": 18}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42, crop_size=84) + self.assertEqual(feature_extractor.size, {"shortest_edge": 42}) + self.assertEqual(feature_extractor.crop_size, {"height": 84, "width": 84}) + def test_batch_feature(self): pass diff --git a/tests/models/vilt/test_feature_extraction_vilt.py b/tests/models/vilt/test_feature_extraction_vilt.py index d2e0d2e803..5816eacf83 100644 --- a/tests/models/vilt/test_feature_extraction_vilt.py +++ b/tests/models/vilt/test_feature_extraction_vilt.py @@ -136,6 +136,13 @@ class ViltFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC self.assertTrue(hasattr(feature_extractor, "size")) self.assertTrue(hasattr(feature_extractor, "size_divisor")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"shortest_edge": 30}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42) + self.assertEqual(feature_extractor.size, {"shortest_edge": 42}) + def test_batch_feature(self): pass diff --git a/tests/models/vit/test_feature_extraction_vit.py b/tests/models/vit/test_feature_extraction_vit.py index e33b7361ab..f419742509 100644 --- a/tests/models/vit/test_feature_extraction_vit.py +++ b/tests/models/vit/test_feature_extraction_vit.py @@ -92,6 +92,13 @@ class ViTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCa self.assertTrue(hasattr(feature_extractor, "do_resize")) self.assertTrue(hasattr(feature_extractor, "size")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"height": 18, "width": 18}) + + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict, size=42) + self.assertEqual(feature_extractor.size, {"height": 42, "width": 42}) + def test_batch_feature(self): pass diff --git a/tests/models/yolos/test_feature_extraction_yolos.py b/tests/models/yolos/test_feature_extraction_yolos.py index 162146e6b6..2c1571d7f7 100644 --- a/tests/models/yolos/test_feature_extraction_yolos.py +++ b/tests/models/yolos/test_feature_extraction_yolos.py @@ -133,6 +133,17 @@ class YolosFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.Test self.assertTrue(hasattr(feature_extractor, "do_resize")) self.assertTrue(hasattr(feature_extractor, "size")) + def test_feat_extract_from_dict_with_kwargs(self): + feature_extractor = self.feature_extraction_class.from_dict(self.feat_extract_dict) + self.assertEqual(feature_extractor.size, {"shortest_edge": 18, "longest_edge": 1333}) + self.assertEqual(feature_extractor.do_pad, True) + + feature_extractor = self.feature_extraction_class.from_dict( + self.feat_extract_dict, size=42, max_size=84, pad_and_return_pixel_mask=False + ) + self.assertEqual(feature_extractor.size, {"shortest_edge": 42, "longest_edge": 84}) + self.assertEqual(feature_extractor.do_pad, False) + def test_batch_feature(self): pass