From 448e050b0d6261603137bd9fa548827a8392e348 Mon Sep 17 00:00:00 2001 From: Naga Sai Abhinay Date: Thu, 23 Feb 2023 13:58:18 +0530 Subject: [PATCH] Make ImageProcessingMixin compatible with subfolder kwarg (#21725) * Add subfolder support * Add kwarg docstring * formatting fix * Add test --- src/transformers/image_processing_utils.py | 8 ++++++++ tests/test_image_processing_common.py | 11 +++++++++++ 2 files changed, 19 insertions(+) diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py index feff54a3ff..cce3d475ba 100644 --- a/src/transformers/image_processing_utils.py +++ b/src/transformers/image_processing_utils.py @@ -128,6 +128,9 @@ class ImageProcessingMixin(PushToHubMixin): functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of `kwargs` which has not been used to update `image_processor` and is otherwise ignored. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. kwargs (`Dict[str, Any]`, *optional*): The values in kwargs of any keys which are image processor attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is @@ -221,6 +224,9 @@ class ImageProcessingMixin(PushToHubMixin): Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`): The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. Returns: `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object. 
@@ -232,6 +238,7 @@ class ImageProcessingMixin(PushToHubMixin): use_auth_token = kwargs.pop("use_auth_token", None) local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) @@ -269,6 +276,7 @@ class ImageProcessingMixin(PushToHubMixin): use_auth_token=use_auth_token, user_agent=user_agent, revision=revision, + subfolder=subfolder, ) except EnvironmentError: # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py index e18f8bf60f..32be6e0e63 100644 --- a/tests/test_image_processing_common.py +++ b/tests/test_image_processing_common.py @@ -311,3 +311,14 @@ class ImageProcessorPushToHubTester(unittest.TestCase): ) # Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor") + + def test_image_processor_from_pretrained_subfolder(self): + with self.assertRaises(OSError): + # config is in subfolder, the following should not work without specifying the subfolder + _ = AutoImageProcessor.from_pretrained("hf-internal-testing/stable-diffusion-all-variants") + + config = AutoImageProcessor.from_pretrained( + "hf-internal-testing/stable-diffusion-all-variants", subfolder="feature_extractor" + ) + + self.assertIsNotNone(config)