Fix test_number_of_steps_in_training_with_ipex (#17889)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-06-28 08:55:02 +02:00
committed by GitHub
parent 0b0dd97737
commit f717d47fe0

View File

@@ -649,14 +649,14 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# Regular training has n_epochs * len(train_dl) steps # Regular training has n_epochs * len(train_dl) steps
trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, no_cuda=True) trainer = get_regression_trainer(learning_rate=0.1, use_ipex=True, bf16=mix_bf16, no_cuda=True)
train_output = trainer.train() train_output = trainer.train()
self.assertEqual(train_output.global_step, self.n_epochs * 64 / self.batch_size) self.assertEqual(train_output.global_step, self.n_epochs * 64 / trainer.args.train_batch_size)
# Check passing num_train_epochs works (and a float version too): # Check passing num_train_epochs works (and a float version too):
trainer = get_regression_trainer( trainer = get_regression_trainer(
learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, no_cuda=True learning_rate=0.1, num_train_epochs=1.5, use_ipex=True, bf16=mix_bf16, no_cuda=True
) )
train_output = trainer.train() train_output = trainer.train()
self.assertEqual(train_output.global_step, int(1.5 * 64 / self.batch_size)) self.assertEqual(train_output.global_step, int(1.5 * 64 / trainer.args.train_batch_size))
# If we pass a max_steps, num_train_epochs is ignored # If we pass a max_steps, num_train_epochs is ignored
trainer = get_regression_trainer( trainer = get_regression_trainer(