diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index e535c3dbde..453d73e360 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -30,6 +30,7 @@ from .dynamic_module_utils import custom_object_save from .file_utils import ( FEATURE_EXTRACTOR_NAME, EntryNotFoundError, + PushToHubMixin, RepositoryNotFoundError, RevisionNotFoundError, TensorType, @@ -37,6 +38,7 @@ from .file_utils import ( _is_numpy, _is_torch_device, cached_path, + copy_func, hf_bucket_url, is_flax_available, is_offline_mode, @@ -200,7 +202,7 @@ class BatchFeature(UserDict): return self -class FeatureExtractionMixin: +class FeatureExtractionMixin(PushToHubMixin): """ This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature extractors. @@ -308,7 +310,7 @@ class FeatureExtractionMixin: return cls.from_dict(feature_extractor_dict, **kwargs) - def save_pretrained(self, save_directory: Union[str, os.PathLike]): + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): """ Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method. @@ -316,10 +318,27 @@ class FeatureExtractionMixin: Args: save_directory (`str` or `os.PathLike`): Directory where the feature extractor JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your feature extractor to the Hugging Face model hub after saving it. + + <Tip warning={true}> + + Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`, + which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing + folder. Pass along `temp_dir=True` to use a temporary directory instead. 
+ + </Tip> + + kwargs: + Additional keyword arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method. """ if os.path.isfile(save_directory): raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo = self._create_or_get_repo(save_directory, **kwargs) + # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be # loaded from the Hub. if self._auto_class is not None: @@ -330,7 +349,11 @@ class FeatureExtractionMixin: output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME) self.to_json_file(output_feature_extractor_file) - logger.info(f"Configuration saved in {output_feature_extractor_file}") + logger.info(f"Feature extractor saved in {output_feature_extractor_file}") + + if push_to_hub: + url = self._push_to_hub(repo, commit_message=commit_message) + logger.info(f"Feature extractor pushed to the hub in this commit: {url}") @classmethod def get_feature_extractor_dict( @@ -574,3 +597,9 @@ class FeatureExtractionMixin: raise ValueError(f"{auto_class} is not a valid auto class.") cls._auto_class = auto_class + + +FeatureExtractionMixin.push_to_hub = copy_func(FeatureExtractionMixin.push_to_hub) +FeatureExtractionMixin.push_to_hub.__doc__ = FeatureExtractionMixin.push_to_hub.__doc__.format( + object="feature extractor", object_class="AutoFeatureExtractor", object_files="feature extractor file" +) diff --git a/tests/test_feature_extraction_common.py b/tests/test_feature_extraction_common.py index 931ee2444e..861617ec90 100644 --- a/tests/test_feature_extraction_common.py +++ b/tests/test_feature_extraction_common.py @@ -23,7 +23,7 @@ from pathlib import Path from huggingface_hub import Repository, delete_repo, login from requests.exceptions import HTTPError -from transformers import AutoFeatureExtractor +from transformers import AutoFeatureExtractor, 
Wav2Vec2FeatureExtractor from transformers.file_utils import is_torch_available, is_vision_available from transformers.testing_utils import PASS, USER, is_staging_test @@ -40,7 +40,6 @@ if is_torch_available(): if is_vision_available(): from PIL import Image - SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures") @@ -124,11 +123,47 @@ class ConfigPushToHubTester(unittest.TestCase): @classmethod def tearDownClass(cls): + try: + delete_repo(token=cls._token, name="test-feature-extractor") + except HTTPError: + pass + + try: + delete_repo(token=cls._token, name="test-feature-extractor-org", organization="valid_org") + except HTTPError: + pass + try: delete_repo(token=cls._token, name="test-dynamic-feature-extractor") except HTTPError: pass + def test_push_to_hub(self): + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR) + with tempfile.TemporaryDirectory() as tmp_dir: + feature_extractor.save_pretrained( + os.path.join(tmp_dir, "test-feature-extractor"), push_to_hub=True, use_auth_token=self._token + ) + + new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor") + for k, v in feature_extractor.__dict__.items(): + self.assertEqual(v, getattr(new_feature_extractor, k)) + + def test_push_to_hub_in_organization(self): + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR) + + with tempfile.TemporaryDirectory() as tmp_dir: + feature_extractor.save_pretrained( + os.path.join(tmp_dir, "test-feature-extractor-org"), + push_to_hub=True, + use_auth_token=self._token, + organization="valid_org", + ) + + new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org") + for k, v in feature_extractor.__dict__.items(): + self.assertEqual(v, getattr(new_feature_extractor, k)) + def test_push_to_hub_dynamic_feature_extractor(self): 
CustomFeatureExtractor.register_for_auto_class() feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)