Add push to hub to feature extractor (#15632)
* Add push to hub to feature extractor * Quality * Clean up
This commit is contained in:
@@ -23,7 +23,7 @@ from pathlib import Path
|
||||
|
||||
from huggingface_hub import Repository, delete_repo, login
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import AutoFeatureExtractor
|
||||
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
||||
from transformers.file_utils import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import PASS, USER, is_staging_test
|
||||
|
||||
@@ -40,7 +40,6 @@ if is_torch_available():
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
||||
|
||||
|
||||
@@ -124,11 +123,47 @@ class ConfigPushToHubTester(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
delete_repo(token=cls._token, name="test-feature-extractor")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, name="test-feature-extractor-org", organization="valid_org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, name="test-dynamic-feature-extractor")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
def test_push_to_hub(self):
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
feature_extractor.save_pretrained(
|
||||
os.path.join(tmp_dir, "test-feature-extractor"), push_to_hub=True, use_auth_token=self._token
|
||||
)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
feature_extractor.save_pretrained(
|
||||
os.path.join(tmp_dir, "test-feature-extractor-org"),
|
||||
push_to_hub=True,
|
||||
use_auth_token=self._token,
|
||||
organization="valid_org",
|
||||
)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org")
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
def test_push_to_hub_dynamic_feature_extractor(self):
|
||||
CustomFeatureExtractor.register_for_auto_class()
|
||||
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
|
||||
Reference in New Issue
Block a user