From 636b03244cb3c5bac6d12a5a968d5024e0fde7c3 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Fri, 16 Feb 2024 10:04:24 -0500 Subject: [PATCH] Fix trainer test wrt DeepSpeed + auto_find_bs (#29061) * FIx trainer test * Update tests/trainer/test_trainer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- tests/trainer/test_trainer.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b03423bde2..87e95a7ea3 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1588,18 +1588,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): auto_find_batch_size=True, deepspeed=deepspeed, ) - trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()]) - trainer.train() - # After `auto_find_batch_size` is ran we should now be at 8 - self.assertEqual(trainer._train_batch_size, 8) - - # We can then make a new Trainer - trainer = Trainer(model, args, train_dataset=train_dataset) - # Check we are at 16 to start - self.assertEqual(trainer._train_batch_size, 16 * max(trainer.args.n_gpu, 1)) - trainer.train(resume_from_checkpoint=True) - # We should be back to 8 again, picking up based upon the last ran Trainer - self.assertEqual(trainer._train_batch_size, 8) + # Note: This can have issues, for now we don't support this functionality + # ref: https://github.com/huggingface/transformers/pull/29057 + with self.assertRaises(NotImplementedError): + _ = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()]) def test_auto_batch_size_with_resume_from_checkpoint(self): train_dataset = RegressionDataset(length=128)