Add push_to_hub method to processors (#15668)

* Add push_to_hub method to processors

* Fix test

* The other one too!
This commit is contained in:
Sylvain Gugger
2022-02-15 21:14:04 -05:00
committed by GitHub
parent bee361c6f1
commit 2d02f7b29b
2 changed files with 73 additions and 4 deletions

View File

@@ -41,7 +41,7 @@ SAMPLE_PROCESSOR_CONFIG = os.path.join(
)
# Path to the "fixtures/vocab.json" file located next to this test module.
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
# NOTE(review): the two assignments below bind the same "fixtures" directory (next to this
# module) under two different names. This looks like the removed/added pair of a rename
# shown by the diff view -- confirm which name the file actually keeps.
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
class AutoFeatureExtractorTest(unittest.TestCase):
@@ -165,17 +165,55 @@ class ProcessorPushToHubTester(unittest.TestCase):
@classmethod
def tearDownClass(cls):
    """Delete the test repositories, ignoring an HTTPError raised for any of them.

    Each repository is handled independently, so a failure to delete one
    does not prevent the attempt on the next.
    """
    # Keyword arguments identifying each repository, in deletion order.
    repos_to_delete = (
        {"name": "test-processor"},
        {"name": "test-processor-org", "organization": "valid_org"},
        {"name": "test-dynamic-processor"},
    )
    for repo_kwargs in repos_to_delete:
        try:
            delete_repo(token=cls._token, **repo_kwargs)
        except HTTPError:
            # Cleanup is best-effort: an HTTP error here is deliberately ignored.
            pass
def test_push_to_hub(self):
    """Push a processor via `save_pretrained(push_to_hub=True)` and check the reloaded copy.

    The processor is loaded from the local fixtures directory, saved with
    `push_to_hub=True`, then reloaded from `{USER}/test-processor`. Every
    attribute of the feature extractor and the tokenizer vocabulary must match.
    """
    original = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
    with tempfile.TemporaryDirectory() as workdir:
        target_dir = os.path.join(workdir, "test-processor")
        original.save_pretrained(target_dir, push_to_hub=True, use_auth_token=self._token)

        reloaded = Wav2Vec2Processor.from_pretrained(f"{USER}/test-processor")
        # Compare the feature extractors attribute by attribute.
        for attr_name, expected in original.feature_extractor.__dict__.items():
            self.assertEqual(expected, getattr(reloaded.feature_extractor, attr_name))
        self.assertDictEqual(reloaded.tokenizer.get_vocab(), original.tokenizer.get_vocab())
def test_push_to_hub_in_organization(self):
    """Push a processor to an organization repository and check the reloaded copy.

    Same flow as the user-namespace test, except that `organization="valid_org"`
    is passed to `save_pretrained` and the processor is reloaded from
    `valid_org/test-processor-org`.
    """
    original = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
    with tempfile.TemporaryDirectory() as workdir:
        target_dir = os.path.join(workdir, "test-processor-org")
        push_options = {
            "push_to_hub": True,
            "use_auth_token": self._token,
            "organization": "valid_org",
        }
        original.save_pretrained(target_dir, **push_options)

        reloaded = Wav2Vec2Processor.from_pretrained("valid_org/test-processor-org")
        # Compare the feature extractors attribute by attribute.
        for attr_name, expected in original.feature_extractor.__dict__.items():
            self.assertEqual(expected, getattr(reloaded.feature_extractor, attr_name))
        self.assertDictEqual(reloaded.tokenizer.get_vocab(), original.tokenizer.get_vocab())
def test_push_to_hub_dynamic_processor(self):
CustomFeatureExtractor.register_for_auto_class()
CustomTokenizer.register_for_auto_class()
CustomProcessor.register_for_auto_class()
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt")