Fix trainer test wrt DeepSpeed + auto_find_bs (#29061)
* FIx trainer test * Update tests/trainer/test_trainer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -1588,18 +1588,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
auto_find_batch_size=True,
|
auto_find_batch_size=True,
|
||||||
deepspeed=deepspeed,
|
deepspeed=deepspeed,
|
||||||
)
|
)
|
||||||
trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
|
# Note: This can have issues, for now we don't support this functionality
|
||||||
trainer.train()
|
# ref: https://github.com/huggingface/transformers/pull/29057
|
||||||
# After `auto_find_batch_size` is ran we should now be at 8
|
with self.assertRaises(NotImplementedError):
|
||||||
self.assertEqual(trainer._train_batch_size, 8)
|
_ = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
|
||||||
|
|
||||||
# We can then make a new Trainer
|
|
||||||
trainer = Trainer(model, args, train_dataset=train_dataset)
|
|
||||||
# Check we are at 16 to start
|
|
||||||
self.assertEqual(trainer._train_batch_size, 16 * max(trainer.args.n_gpu, 1))
|
|
||||||
trainer.train(resume_from_checkpoint=True)
|
|
||||||
# We should be back to 8 again, picking up based upon the last ran Trainer
|
|
||||||
self.assertEqual(trainer._train_batch_size, 8)
|
|
||||||
|
|
||||||
def test_auto_batch_size_with_resume_from_checkpoint(self):
|
def test_auto_batch_size_with_resume_from_checkpoint(self):
|
||||||
train_dataset = RegressionDataset(length=128)
|
train_dataset = RegressionDataset(length=128)
|
||||||
|
|||||||
Reference in New Issue
Block a user