@@ -21,6 +21,7 @@ import random
|
|||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
@@ -1544,12 +1545,17 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
# Wait for the async pushes to be finished
|
||||||
|
while trainer.push_in_progress is not None and not trainer.push_in_progress.is_done:
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-epoch", use_auth_token=self._token)
|
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-epoch", use_auth_token=self._token)
|
||||||
commits = self.get_commit_history(tmp_dir)
|
commits = self.get_commit_history(tmp_dir)
|
||||||
expected_commits = [f"Training in progress, epoch {i}" for i in range(3, 0, -1)]
|
self.assertIn("initial commit", commits)
|
||||||
expected_commits.append("initial commit")
|
# We can't test that epoch 2 and 3 are in the commits without being flaky as those might be skipped if
|
||||||
self.assertListEqual(commits, expected_commits)
|
# the push for epoch 1 wasn't finished at the time.
|
||||||
|
self.assertIn("Training in progress, epoch 1", commits)
|
||||||
|
|
||||||
def test_push_to_hub_with_saves_each_n_steps(self):
|
def test_push_to_hub_with_saves_each_n_steps(self):
|
||||||
num_gpus = max(1, get_gpu_count())
|
num_gpus = max(1, get_gpu_count())
|
||||||
@@ -1566,13 +1572,17 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
# Wait for the async pushes to be finished
|
||||||
|
while trainer.push_in_progress is not None and not trainer.push_in_progress.is_done:
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token)
|
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token)
|
||||||
commits = self.get_commit_history(tmp_dir)
|
commits = self.get_commit_history(tmp_dir)
|
||||||
total_steps = 20 // num_gpus
|
self.assertIn("initial commit", commits)
|
||||||
expected_commits = [f"Training in progress, step {i}" for i in range(total_steps, 0, -5)]
|
# We can't test that the later steps (10, 15, ...) are in the commits without being flaky as those might be skipped if
|
||||||
expected_commits.append("initial commit")
|
# the push for step 5 wasn't finished at the time.
|
||||||
self.assertListEqual(commits, expected_commits)
|
self.assertIn("Training in progress, step 5", commits)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user