Clean push to hub API (#12187)

* Clean push to hub API

* Create working dir if it does not exist

* Different tweak

* New API + all models + test Flax

* Adds the Trainer clean up

* Update src/transformers/file_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Address review comments

* (nit) output types

* No need to set clone_from when folder exists

* Update src/transformers/trainer.py

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Add generated_from_trainer tag

* Update to new version

* Fixes

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Julien Chaumond <julien@huggingface.co>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Sylvain Gugger
2021-06-23 10:11:19 -04:00
committed by GitHub
parent 625f512d5e
commit 53c60babe4
17 changed files with 368 additions and 159 deletions

View File

@@ -1274,8 +1274,12 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(output_dir=tmp_dir)
url = trainer.push_to_hub(repo_name="test-trainer", use_auth_token=self._token)
trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer"),
push_to_hub=True,
push_to_hub_token=self._token,
)
url = trainer.push_to_hub()
# Extract repo_name from the url
re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url)
@@ -1292,9 +1296,13 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(output_dir=tmp_dir)
trainer.save_model()
url = trainer.push_to_hub(
repo_name="test-trainer-org", organization="valid_org", use_auth_token=self._token
trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer-org"),
push_to_hub=True,
push_to_hub_organization="valid_org",
push_to_hub_token=self._token,
)
url = trainer.push_to_hub()
# Extract repo_name from the url
re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url)