This commit is contained in:
Marcin Zabłocki
2020-09-28 10:09:26 +02:00
committed by GitHub
parent ae3e84f3ba
commit 4083a55ab0
2 changed files with 34 additions and 7 deletions

View File

@@ -336,3 +336,16 @@ class TrainerIntegrationTest(unittest.TestCase):
trainer = get_regression_trainer(train_len=64, per_device_train_batch_size=16, gradient_accumulation_steps=5)
train_output = trainer.train()
self.assertEqual(train_output.global_step, int(self.n_epochs))
def test_flos_extraction(self):
trainer = get_regression_trainer(learning_rate=0.1)
def assert_flos_extraction(trainer, wrapped_model_to_check):
self.assertEqual(trainer.model, trainer._actual_model(wrapped_model_to_check))
self.assertGreaterEqual(getattr(trainer._actual_model(wrapped_model_to_check).config, "total_flos", 0), 0)
# with plain model
assert_flos_extraction(trainer, trainer.model)
# with enforced DataParallel
assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))