Make ImageProcessorMixin compatible with subfolder kwarg (#21725)

* Add subfolder support

* Add kwarg docstring

* formatting fix

* Add test
This commit is contained in:
Naga Sai Abhinay
2023-02-23 13:58:18 +05:30
committed by GitHub
parent 064f374874
commit 448e050b0d
2 changed files with 19 additions and 0 deletions

View File

@@ -128,6 +128,9 @@ class ImageProcessingMixin(PushToHubMixin):
functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
`kwargs` which has not been used to update `image_processor` and is otherwise ignored. `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
kwargs (`Dict[str, Any]`, *optional*): kwargs (`Dict[str, Any]`, *optional*):
The values in kwargs of any keys which are image processor attributes will be used to override the The values in kwargs of any keys which are image processor attributes will be used to override the
loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
@@ -221,6 +224,9 @@ class ImageProcessingMixin(PushToHubMixin):
Parameters: Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`): pretrained_model_name_or_path (`str` or `os.PathLike`):
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
Returns: Returns:
`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object. `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
@@ -232,6 +238,7 @@ class ImageProcessingMixin(PushToHubMixin):
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None) revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", "")
from_pipeline = kwargs.pop("_from_pipeline", None) from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False) from_auto_class = kwargs.pop("_from_auto", False)
@@ -269,6 +276,7 @@ class ImageProcessingMixin(PushToHubMixin):
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
user_agent=user_agent, user_agent=user_agent,
revision=revision, revision=revision,
subfolder=subfolder,
) )
except EnvironmentError: except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to

View File

@@ -311,3 +311,14 @@ class ImageProcessorPushToHubTester(unittest.TestCase):
) )
# Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module # Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module
self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor") self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor")
def test_image_processor_from_pretrained_subfolder(self):
with self.assertRaises(OSError):
# config is in subfolder, the following should not work without specifying the subfolder
_ = AutoImageProcessor.from_pretrained("hf-internal-testing/stable-diffusion-all-variants")
config = AutoImageProcessor.from_pretrained(
"hf-internal-testing/stable-diffusion-all-variants", subfolder="feature_extractor"
)
self.assertIsNotNone(config)