Add push to hub to feature extractor (#15632)
* Add push to hub to feature extractor * Quality * Clean up
This commit is contained in:
@@ -30,6 +30,7 @@ from .dynamic_module_utils import custom_object_save
|
||||
from .file_utils import (
|
||||
FEATURE_EXTRACTOR_NAME,
|
||||
EntryNotFoundError,
|
||||
PushToHubMixin,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
TensorType,
|
||||
@@ -37,6 +38,7 @@ from .file_utils import (
|
||||
_is_numpy,
|
||||
_is_torch_device,
|
||||
cached_path,
|
||||
copy_func,
|
||||
hf_bucket_url,
|
||||
is_flax_available,
|
||||
is_offline_mode,
|
||||
@@ -200,7 +202,7 @@ class BatchFeature(UserDict):
|
||||
return self
|
||||
|
||||
|
||||
class FeatureExtractionMixin:
|
||||
class FeatureExtractionMixin(PushToHubMixin):
|
||||
"""
|
||||
This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature
|
||||
extractors.
|
||||
@@ -308,7 +310,7 @@ class FeatureExtractionMixin:
|
||||
|
||||
return cls.from_dict(feature_extractor_dict, **kwargs)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the
|
||||
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.
|
||||
@@ -316,10 +318,27 @@ class FeatureExtractionMixin:
|
||||
Args:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your feature extractor to the Hugging Face model hub after saving it.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
|
||||
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
|
||||
folder. Pass along `temp_dir=True` to use a temporary directory instead.
|
||||
|
||||
</Tip>
|
||||
|
||||
kwargs:
|
||||
Additional key word arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
repo = self._create_or_get_repo(save_directory, **kwargs)
|
||||
|
||||
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
|
||||
# loaded from the Hub.
|
||||
if self._auto_class is not None:
|
||||
@@ -330,7 +349,11 @@ class FeatureExtractionMixin:
|
||||
output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
|
||||
|
||||
self.to_json_file(output_feature_extractor_file)
|
||||
logger.info(f"Configuration saved in {output_feature_extractor_file}")
|
||||
logger.info(f"Feature extractor saved in {output_feature_extractor_file}")
|
||||
|
||||
if push_to_hub:
|
||||
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||
logger.info(f"Feature extractor pushed to the hub in this commit: {url}")
|
||||
|
||||
@classmethod
|
||||
def get_feature_extractor_dict(
|
||||
@@ -574,3 +597,9 @@ class FeatureExtractionMixin:
|
||||
raise ValueError(f"{auto_class} is not a valid auto class.")
|
||||
|
||||
cls._auto_class = auto_class
|
||||
|
||||
|
||||
FeatureExtractionMixin.push_to_hub = copy_func(FeatureExtractionMixin.push_to_hub)
|
||||
FeatureExtractionMixin.push_to_hub.__doc__ = FeatureExtractionMixin.push_to_hub.__doc__.format(
|
||||
object="feature extractor", object_class="AutoFeatureExtractor", object_files="feature extractor file"
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ from pathlib import Path
|
||||
|
||||
from huggingface_hub import Repository, delete_repo, login
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import AutoFeatureExtractor
|
||||
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
||||
from transformers.file_utils import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import PASS, USER, is_staging_test
|
||||
|
||||
@@ -40,7 +40,6 @@ if is_torch_available():
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
||||
|
||||
|
||||
@@ -124,11 +123,47 @@ class ConfigPushToHubTester(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
delete_repo(token=cls._token, name="test-feature-extractor")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, name="test-feature-extractor-org", organization="valid_org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, name="test-dynamic-feature-extractor")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
def test_push_to_hub(self):
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
feature_extractor.save_pretrained(
|
||||
os.path.join(tmp_dir, "test-feature-extractor"), push_to_hub=True, use_auth_token=self._token
|
||||
)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
feature_extractor.save_pretrained(
|
||||
os.path.join(tmp_dir, "test-feature-extractor-org"),
|
||||
push_to_hub=True,
|
||||
use_auth_token=self._token,
|
||||
organization="valid_org",
|
||||
)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org")
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
def test_push_to_hub_dynamic_feature_extractor(self):
|
||||
CustomFeatureExtractor.register_for_auto_class()
|
||||
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
|
||||
Reference in New Issue
Block a user