@@ -1458,6 +1458,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.train(resume_from_checkpoint=True)
|
||||
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
|
||||
|
||||
@unittest.skip(
|
||||
reason="@muellerzr: Fix once Trainer can take an accelerate configuration. Need to set `seedable_sampler=True`."
|
||||
)
|
||||
def test_resume_training_with_randomness(self):
|
||||
# For more than 1 GPUs, since the randomness is introduced in the model and with DataParallel (which is used
|
||||
# in this test for more than 2 GPUs), the calls to the torch RNG will happen in a random order (sometimes
|
||||
|
||||
Reference in New Issue
Block a user