Add push to hub to feature extractor (#15632)
* Add push to hub to feature extractor
* Quality
* Clean up
This commit is contained in:
@@ -30,6 +30,7 @@ from .dynamic_module_utils import custom_object_save
|
|||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
FEATURE_EXTRACTOR_NAME,
|
FEATURE_EXTRACTOR_NAME,
|
||||||
EntryNotFoundError,
|
EntryNotFoundError,
|
||||||
|
PushToHubMixin,
|
||||||
RepositoryNotFoundError,
|
RepositoryNotFoundError,
|
||||||
RevisionNotFoundError,
|
RevisionNotFoundError,
|
||||||
TensorType,
|
TensorType,
|
||||||
@@ -37,6 +38,7 @@ from .file_utils import (
|
|||||||
_is_numpy,
|
_is_numpy,
|
||||||
_is_torch_device,
|
_is_torch_device,
|
||||||
cached_path,
|
cached_path,
|
||||||
|
copy_func,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_offline_mode,
|
is_offline_mode,
|
||||||
@@ -200,7 +202,7 @@ class BatchFeature(UserDict):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class FeatureExtractionMixin:
|
class FeatureExtractionMixin(PushToHubMixin):
|
||||||
"""
|
"""
|
||||||
This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature
|
This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature
|
||||||
extractors.
|
extractors.
|
||||||
@@ -308,7 +310,7 @@ class FeatureExtractionMixin:
|
|||||||
|
|
||||||
return cls.from_dict(feature_extractor_dict, **kwargs)
|
return cls.from_dict(feature_extractor_dict, **kwargs)
|
||||||
|
|
||||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the
|
Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the
|
||||||
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.
|
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.
|
||||||
@@ -316,10 +318,27 @@ class FeatureExtractionMixin:
|
|||||||
Args:
|
Args:
|
||||||
save_directory (`str` or `os.PathLike`):
|
save_directory (`str` or `os.PathLike`):
|
||||||
Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
|
Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
|
||||||
|
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to push your feature extractor to the Hugging Face model hub after saving it.
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
|
||||||
|
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
|
||||||
|
folder. Pass along `temp_dir=True` to use a temporary directory instead.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
kwargs:
|
||||||
|
Additional key word arguments passed along to the [`~file_utils.PushToHubMixin.push_to_hub`] method.
|
||||||
"""
|
"""
|
||||||
if os.path.isfile(save_directory):
|
if os.path.isfile(save_directory):
|
||||||
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||||
|
|
||||||
|
if push_to_hub:
|
||||||
|
commit_message = kwargs.pop("commit_message", None)
|
||||||
|
repo = self._create_or_get_repo(save_directory, **kwargs)
|
||||||
|
|
||||||
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
|
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
|
||||||
# loaded from the Hub.
|
# loaded from the Hub.
|
||||||
if self._auto_class is not None:
|
if self._auto_class is not None:
|
||||||
@@ -330,7 +349,11 @@ class FeatureExtractionMixin:
|
|||||||
output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
|
output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
|
||||||
|
|
||||||
self.to_json_file(output_feature_extractor_file)
|
self.to_json_file(output_feature_extractor_file)
|
||||||
logger.info(f"Configuration saved in {output_feature_extractor_file}")
|
logger.info(f"Feature extractor saved in {output_feature_extractor_file}")
|
||||||
|
|
||||||
|
if push_to_hub:
|
||||||
|
url = self._push_to_hub(repo, commit_message=commit_message)
|
||||||
|
logger.info(f"Feature extractor pushed to the hub in this commit: {url}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_feature_extractor_dict(
|
def get_feature_extractor_dict(
|
||||||
@@ -574,3 +597,9 @@ class FeatureExtractionMixin:
|
|||||||
raise ValueError(f"{auto_class} is not a valid auto class.")
|
raise ValueError(f"{auto_class} is not a valid auto class.")
|
||||||
|
|
||||||
cls._auto_class = auto_class
|
cls._auto_class = auto_class
|
||||||
|
|
||||||
|
|
||||||
|
FeatureExtractionMixin.push_to_hub = copy_func(FeatureExtractionMixin.push_to_hub)
|
||||||
|
FeatureExtractionMixin.push_to_hub.__doc__ = FeatureExtractionMixin.push_to_hub.__doc__.format(
|
||||||
|
object="feature extractor", object_class="AutoFeatureExtractor", object_files="feature extractor file"
|
||||||
|
)
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from huggingface_hub import Repository, delete_repo, login
|
from huggingface_hub import Repository, delete_repo, login
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
||||||
from transformers.file_utils import is_torch_available, is_vision_available
|
from transformers.file_utils import is_torch_available, is_vision_available
|
||||||
from transformers.testing_utils import PASS, USER, is_staging_test
|
from transformers.testing_utils import PASS, USER, is_staging_test
|
||||||
|
|
||||||
@@ -40,7 +40,6 @@ if is_torch_available():
|
|||||||
if is_vision_available():
|
if is_vision_available():
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
||||||
|
|
||||||
|
|
||||||
@@ -124,11 +123,47 @@ class ConfigPushToHubTester(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, name="test-feature-extractor")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, name="test-feature-extractor-org", organization="valid_org")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
delete_repo(token=cls._token, name="test-dynamic-feature-extractor")
|
delete_repo(token=cls._token, name="test-dynamic-feature-extractor")
|
||||||
except HTTPError:
|
except HTTPError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_push_to_hub(self):
|
||||||
|
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
feature_extractor.save_pretrained(
|
||||||
|
os.path.join(tmp_dir, "test-feature-extractor"), push_to_hub=True, use_auth_token=self._token
|
||||||
|
)
|
||||||
|
|
||||||
|
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
|
||||||
|
for k, v in feature_extractor.__dict__.items():
|
||||||
|
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||||
|
|
||||||
|
def test_push_to_hub_in_organization(self):
|
||||||
|
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
feature_extractor.save_pretrained(
|
||||||
|
os.path.join(tmp_dir, "test-feature-extractor-org"),
|
||||||
|
push_to_hub=True,
|
||||||
|
use_auth_token=self._token,
|
||||||
|
organization="valid_org",
|
||||||
|
)
|
||||||
|
|
||||||
|
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org")
|
||||||
|
for k, v in feature_extractor.__dict__.items():
|
||||||
|
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||||
|
|
||||||
def test_push_to_hub_dynamic_feature_extractor(self):
|
def test_push_to_hub_dynamic_feature_extractor(self):
|
||||||
CustomFeatureExtractor.register_for_auto_class()
|
CustomFeatureExtractor.register_for_auto_class()
|
||||||
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||||
|
|||||||
Reference in New Issue
Block a user