Clean push to hub API (#12187)
* Clean push to hub API * Create working dir if it does not exist * Different tweak * New API + all models + test Flax * Adds the Trainer clean up * Update src/transformers/file_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Address review comments * (nit) output types * No need to set clone_from when folder exists * Update src/transformers/trainer.py Co-authored-by: Julien Chaumond <julien@huggingface.co> * Add generated_from_trainer tag * Update to new version * Fixes Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Julien Chaumond <julien@huggingface.co> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -1274,8 +1274,12 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
||||
|
||||
def test_push_to_hub(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
trainer = get_regression_trainer(output_dir=tmp_dir)
|
||||
url = trainer.push_to_hub(repo_name="test-trainer", use_auth_token=self._token)
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=os.path.join(tmp_dir, "test-trainer"),
|
||||
push_to_hub=True,
|
||||
push_to_hub_token=self._token,
|
||||
)
|
||||
url = trainer.push_to_hub()
|
||||
|
||||
# Extract repo_name from the url
|
||||
re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url)
|
||||
@@ -1292,9 +1296,13 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
trainer = get_regression_trainer(output_dir=tmp_dir)
|
||||
trainer.save_model()
|
||||
url = trainer.push_to_hub(
|
||||
repo_name="test-trainer-org", organization="valid_org", use_auth_token=self._token
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=os.path.join(tmp_dir, "test-trainer-org"),
|
||||
push_to_hub=True,
|
||||
push_to_hub_organization="valid_org",
|
||||
push_to_hub_token=self._token,
|
||||
)
|
||||
url = trainer.push_to_hub()
|
||||
|
||||
# Extract repo_name from the url
|
||||
re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url)
|
||||
|
||||
Reference in New Issue
Block a user