@@ -318,9 +318,10 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
yes_grad_accum_b = yes_grad_accum_trainer.model.b.item()
|
||||
self.assertNotEqual(yes_grad_accum_a, a)
|
||||
|
||||
# training with half the batch size but accumulation steps as 2 should give the same weights
|
||||
self.assertEqual(no_grad_accum_a, yes_grad_accum_a)
|
||||
self.assertEqual(no_grad_accum_b, yes_grad_accum_b)
|
||||
# training with half the batch size but accumulation steps as 2 should give the same
|
||||
# weights, but sometimes get a slight difference still of 1e-6
|
||||
self.assertAlmostEqual(no_grad_accum_a, yes_grad_accum_a, places=5)
|
||||
self.assertAlmostEqual(no_grad_accum_b, yes_grad_accum_b, places=5)
|
||||
|
||||
# see the note above how to get identical loss on a small bs
|
||||
self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=5)
|
||||
|
||||
Reference in New Issue
Block a user