From 315e67404d24c971cd690541f34651d7bc635cec Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 9 Feb 2022 12:27:59 -0500 Subject: [PATCH] Fix tests hub failure (#15580) * Expose hub test problem * Fix tests --- tests/test_trainer.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 4c4ecb54c1..a730c0df6d 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -21,6 +21,7 @@ import random import re import subprocess import tempfile +import time import unittest from pathlib import Path from unittest.mock import Mock, patch @@ -1544,12 +1545,17 @@ class TrainerIntegrationWithHubTester(unittest.TestCase): ) trainer.train() + # Wait for the async pushes to be finished + while trainer.push_in_progress is not None and not trainer.push_in_progress.is_done: + time.sleep(0.5) + with tempfile.TemporaryDirectory() as tmp_dir: _ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-epoch", use_auth_token=self._token) commits = self.get_commit_history(tmp_dir) - expected_commits = [f"Training in progress, epoch {i}" for i in range(3, 0, -1)] - expected_commits.append("initial commit") - self.assertListEqual(commits, expected_commits) + self.assertIn("initial commit", commits) + # We can't test that epoch 2 and 3 are in the commits without being flaky as those might be skipped if + # the push for epoch 1 wasn't finished at the time. + self.assertIn("Training in progress, epoch 1", commits) def test_push_to_hub_with_saves_each_n_steps(self): num_gpus = max(1, get_gpu_count()) @@ -1566,13 +1572,17 @@ class TrainerIntegrationWithHubTester(unittest.TestCase): ) trainer.train() + # Wait for the async pushes to be finished + while trainer.push_in_progress is not None and not trainer.push_in_progress.is_done: + time.sleep(0.5) + with tempfile.TemporaryDirectory() as tmp_dir: _ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token) commits = self.get_commit_history(tmp_dir) - total_steps = 20 // num_gpus - expected_commits = [f"Training in progress, step {i}" for i in range(total_steps, 0, -5)] - expected_commits.append("initial commit") - self.assertListEqual(commits, expected_commits) + self.assertIn("initial commit", commits) + # We can't test that epoch 2 and 3 are in the commits without being flaky as those might be skipped if + # the push for epoch 1 wasn't finished at the time. + self.assertIn("Training in progress, step 5", commits) @require_torch