Push to hub when saving checkpoints (#13503)

* Push to hub when saving checkpoints

* Add model card

* Revert partial model card

* Small fix for checkpoint

* Add tests

* Add documentation

* Fix tests

* Bump huggingface_hub

* Fix test
This commit is contained in:
Sylvain Gugger
2021-09-14 08:02:15 -04:00
committed by GitHub
parent 51e5eca612
commit 3081d3868e
7 changed files with 227 additions and 44 deletions

View File

@@ -18,13 +18,14 @@ import gc
import os
import random
import re
import subprocess
import tempfile
import unittest
from pathlib import Path
import numpy as np
from huggingface_hub import HfApi
from huggingface_hub import HfApi, Repository
from requests.exceptions import HTTPError
from transformers import (
AutoTokenizer,
@@ -1284,10 +1285,11 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
@classmethod
def tearDownClass(cls):
try:
cls._api.delete_repo(token=cls._token, name="test-trainer")
except HTTPError:
pass
for model in ["test-trainer", "test-trainer-epoch", "test-trainer-step"]:
try:
cls._api.delete_repo(token=cls._token, name=model)
except HTTPError:
pass
try:
cls._api.delete_repo(token=cls._token, name="test-trainer-org", organization="valid_org")
@@ -1336,6 +1338,55 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
self.assertEqual(model.a.item(), trainer.model.a.item())
self.assertEqual(model.b.item(), trainer.model.b.item())
def get_commit_history(self, repo):
commit_logs = subprocess.run(
"git log".split(),
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
check=True,
encoding="utf-8",
cwd=repo,
).stdout
commits = commit_logs.split("\n\n")[1::2]
return [commit.strip() for commit in commits]
def test_push_to_hub_with_saves_each_epoch(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer-epoch"),
push_to_hub=True,
hub_token=self._token,
save_strategy="epoch",
)
trainer.train()
with tempfile.TemporaryDirectory() as tmp_dir:
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-epoch", use_auth_token=self._token)
commits = self.get_commit_history(tmp_dir)
expected_commits = [f"Training in progress, epoch {i}" for i in range(3, 0, -1)]
expected_commits.append("initial commit")
self.assertListEqual(commits, expected_commits)
print(commits, len(commits))
def test_push_to_hub_with_saves_each_n_steps(self):
with tempfile.TemporaryDirectory() as tmp_dir:
trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer-step"),
push_to_hub=True,
hub_token=self._token,
save_strategy="steps",
save_steps=5,
)
trainer.train()
with tempfile.TemporaryDirectory() as tmp_dir:
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token)
commits = self.get_commit_history(tmp_dir)
expected_commits = [f"Training in progress, step {i}" for i in range(20, 0, -5)]
expected_commits.append("initial commit")
self.assertListEqual(commits, expected_commits)
print(commits, len(commits))
@require_torch
@require_optuna