Flos fix (#7384)
This commit is contained in:
@@ -336,3 +336,16 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
trainer = get_regression_trainer(train_len=64, per_device_train_batch_size=16, gradient_accumulation_steps=5)
|
||||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, int(self.n_epochs))
|
||||
|
||||
def test_flos_extraction(self):
|
||||
trainer = get_regression_trainer(learning_rate=0.1)
|
||||
|
||||
def assert_flos_extraction(trainer, wrapped_model_to_check):
|
||||
self.assertEqual(trainer.model, trainer._actual_model(wrapped_model_to_check))
|
||||
self.assertGreaterEqual(getattr(trainer._actual_model(wrapped_model_to_check).config, "total_flos", 0), 0)
|
||||
|
||||
# with plain model
|
||||
assert_flos_extraction(trainer, trainer.model)
|
||||
|
||||
# with enforced DataParallel
|
||||
assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
|
||||
|
||||
Reference in New Issue
Block a user