From 6bc6797e04811176f4244a42c86f8a65a1e1c455 Mon Sep 17 00:00:00 2001 From: Heng Kuan Wee Date: Wed, 11 May 2022 20:09:54 +0800 Subject: [PATCH] Convert image to rgb for clip model (#17101) Co-authored-by: kuanwee.heng --- .../models/clip/feature_extraction_clip.py | 22 ++++++- .../clip/test_feature_extraction_clip.py | 65 +++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/clip/feature_extraction_clip.py b/src/transformers/models/clip/feature_extraction_clip.py index 7614d05afd..d5b66961bf 100644 --- a/src/transformers/models/clip/feature_extraction_clip.py +++ b/src/transformers/models/clip/feature_extraction_clip.py @@ -54,6 +54,8 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): The sequence of means for each channel, to be used when normalizing images. image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`): The sequence of standard deviations for each channel, to be used when normalizing images. + convert_rgb (`bool`, defaults to `True`): + Whether or not to convert `PIL.Image.Image` into `RGB` format """ model_input_names = ["pixel_values"] @@ -68,6 +70,7 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): do_normalize=True, image_mean=None, image_std=None, + do_convert_rgb=True, **kwargs ): super().__init__(**kwargs) @@ -79,6 +82,7 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): self.do_normalize = do_normalize self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073] self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711] + self.do_convert_rgb = do_convert_rgb def __call__( self, @@ -141,7 +145,9 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): if not is_batched: images = [images] - # transformations (resizing + center cropping + normalization) + # transformations (convert rgb + resizing + center cropping + normalization) + if self.do_convert_rgb: + images = [self.convert_rgb(image) for image in images] if self.do_resize and self.size is not None and self.resample is not None: images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images] if self.do_center_crop and self.crop_size is not None: @@ -155,6 +161,20 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): return encoded_inputs + def convert_rgb(self, image): + """ + Converts `image` to RGB format. Note that this will trigger a conversion of `image` to a PIL Image. + + Args: + image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`): + The image to convert. + """ + self._ensure_format_supported(image) + if not isinstance(image, Image.Image): + return image + + return image.convert("RGB") + def center_crop(self, image, size): """ Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the diff --git a/tests/models/clip/test_feature_extraction_clip.py b/tests/models/clip/test_feature_extraction_clip.py index a3f0817ea0..8f36a65ae2 100644 --- a/tests/models/clip/test_feature_extraction_clip.py +++ b/tests/models/clip/test_feature_extraction_clip.py @@ -49,6 +49,7 @@ class CLIPFeatureExtractionTester(unittest.TestCase): do_normalize=True, image_mean=[0.48145466, 0.4578275, 0.40821073], image_std=[0.26862954, 0.26130258, 0.27577711], + do_convert_rgb=True, ): self.parent = parent self.batch_size = batch_size @@ -63,6 +64,7 @@ class CLIPFeatureExtractionTester(unittest.TestCase): self.do_normalize = do_normalize self.image_mean = image_mean self.image_std = image_std + self.do_convert_rgb = do_convert_rgb def prepare_feat_extract_dict(self): return { @@ -73,6 +75,7 @@ class CLIPFeatureExtractionTester(unittest.TestCase): "do_normalize": self.do_normalize, "image_mean": self.image_mean, "image_std": self.image_std, + "do_convert_rgb": self.do_convert_rgb, } def prepare_inputs(self, equal_resolution=False, numpify=False, torchify=False): @@ -128,6 +131,7 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC self.assertTrue(hasattr(feature_extractor, "do_normalize")) self.assertTrue(hasattr(feature_extractor, "image_mean")) self.assertTrue(hasattr(feature_extractor, "image_std")) + self.assertTrue(hasattr(feature_extractor, "do_convert_rgb")) def test_batch_feature(self): pass @@ -227,3 +231,64 @@ class CLIPFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestC self.feature_extract_tester.crop_size, ), ) + + +@require_torch +@require_vision +class CLIPFeatureExtractionTestFourChannels(FeatureExtractionSavingTestMixin, unittest.TestCase): + + feature_extraction_class = CLIPFeatureExtractor if is_vision_available() else None + + def setUp(self): + self.feature_extract_tester = CLIPFeatureExtractionTester(self, num_channels=4) + self.expected_encoded_image_num_channels = 3 + + @property + def feat_extract_dict(self): + return self.feature_extract_tester.prepare_feat_extract_dict() + + def test_feat_extract_properties(self): + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + self.assertTrue(hasattr(feature_extractor, "do_resize")) + self.assertTrue(hasattr(feature_extractor, "size")) + self.assertTrue(hasattr(feature_extractor, "do_center_crop")) + self.assertTrue(hasattr(feature_extractor, "center_crop")) + self.assertTrue(hasattr(feature_extractor, "do_normalize")) + self.assertTrue(hasattr(feature_extractor, "image_mean")) + self.assertTrue(hasattr(feature_extractor, "image_std")) + self.assertTrue(hasattr(feature_extractor, "do_convert_rgb")) + + def test_batch_feature(self): + pass + + def test_call_pil_four_channels(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PIL images + image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.expected_encoded_image_num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.expected_encoded_image_num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + )