Fix tests due to breaking change in accelerate (#39451)
* update values * fix
This commit is contained in:
@@ -3394,7 +3394,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
)
|
||||
trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
|
||||
trainer.train()
|
||||
self.assertEqual(trainer._train_batch_size, 8)
|
||||
self.assertEqual(trainer._train_batch_size, 14)
|
||||
|
||||
def test_auto_batch_size_with_resume_from_checkpoint(self):
|
||||
train_dataset = RegressionDataset(length=128)
|
||||
@@ -3414,16 +3414,16 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
)
|
||||
trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
|
||||
trainer.train()
|
||||
# After `auto_find_batch_size` is ran we should now be at 8
|
||||
self.assertEqual(trainer._train_batch_size, 8)
|
||||
# After `auto_find_batch_size` is ran we should now be at 16*0.9=14
|
||||
self.assertEqual(trainer._train_batch_size, 14)
|
||||
|
||||
# We can then make a new Trainer
|
||||
trainer = Trainer(model, args, train_dataset=train_dataset)
|
||||
# Check we are at 16 to start
|
||||
self.assertEqual(trainer._train_batch_size, 16 * max(trainer.args.n_gpu, 1))
|
||||
trainer.train(resume_from_checkpoint=True)
|
||||
# We should be back to 8 again, picking up based upon the last ran Trainer
|
||||
self.assertEqual(trainer._train_batch_size, 8)
|
||||
# We should be back to 14 again, picking up based upon the last ran Trainer
|
||||
self.assertEqual(trainer._train_batch_size, 14)
|
||||
|
||||
# regression for this issue: https://github.com/huggingface/transformers/issues/12970
|
||||
def test_training_with_resume_from_checkpoint_false(self):
|
||||
|
||||
@@ -464,7 +464,7 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
raise RuntimeError("CUDA out of memory.")
|
||||
|
||||
mock_training_loop_function()
|
||||
self.assertEqual(batch_sizes, [64, 32, 16])
|
||||
self.assertEqual(batch_sizes, [64, 57, 51, 45, 40, 36, 32, 28, 25, 22, 19, 17, 15])
|
||||
|
||||
@require_accelerate
|
||||
def test_executable_batch_size_no_search(self):
|
||||
|
||||
Reference in New Issue
Block a user