Add push to hub to feature extractor (#15632)

* Add push to hub to feature extractor

* Quality

* Clean up
This commit is contained in:
Sylvain Gugger
2022-02-11 17:14:01 -05:00
committed by GitHub
parent 4f403ea899
commit 52d2e6f6e9
2 changed files with 69 additions and 5 deletions

View File

@@ -23,7 +23,7 @@ from pathlib import Path
from huggingface_hub import Repository, delete_repo, login
from requests.exceptions import HTTPError
from transformers import AutoFeatureExtractor
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import PASS, USER, is_staging_test
@@ -40,7 +40,6 @@ if is_torch_available():
if is_vision_available():
from PIL import Image
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
@@ -124,11 +123,47 @@ class ConfigPushToHubTester(unittest.TestCase):
@classmethod
def tearDownClass(cls):
try:
delete_repo(token=cls._token, name="test-feature-extractor")
except HTTPError:
pass
try:
delete_repo(token=cls._token, name="test-feature-extractor-org", organization="valid_org")
except HTTPError:
pass
try:
delete_repo(token=cls._token, name="test-dynamic-feature-extractor")
except HTTPError:
pass
def test_push_to_hub(self):
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained(
os.path.join(tmp_dir, "test-feature-extractor"), push_to_hub=True, use_auth_token=self._token
)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
def test_push_to_hub_in_organization(self):
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained(
os.path.join(tmp_dir, "test-feature-extractor-org"),
push_to_hub=True,
use_auth_token=self._token,
organization="valid_org",
)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org")
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
def test_push_to_hub_dynamic_feature_extractor(self):
CustomFeatureExtractor.register_for_auto_class()
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)