From 4083a55ab048da494a218819a09a4590ff634f65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Zab=C5=82ocki?= Date: Mon, 28 Sep 2020 10:09:26 +0200 Subject: [PATCH] Flos fix (#7384) --- src/transformers/trainer.py | 28 +++++++++++++++++++++------- tests/test_trainer.py | 13 +++++++++++++ 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 5158e5cbbf..3ce676854d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -695,7 +695,7 @@ class Trainer: # set global_step to global_step of last saved checkpoint from model path try: self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0]) - self.total_flos = getattr(model.config, "total_flos", 0) + self.total_flos = getattr(self._actual_model(model).config, "total_flos", 0) epochs_trained = self.global_step // num_update_steps_per_epoch steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch) @@ -1448,15 +1448,29 @@ class Trainer: :obj:`int`: The number of floating-point operations. """ - if isinstance(self.model, torch.nn.DataParallel) or isinstance( - self.model, torch.nn.parallel.DistributedDataParallel - ): - model = self.model.module - else: - model = self.model + model = self._actual_model(self.model) if hasattr(model, "floating_point_ops"): return model.floating_point_ops(inputs) else: return 0 + + @staticmethod + def _actual_model( + model: Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module] + ) -> torch.nn.modules.Module: + """ + + Args: + model: (:obj:`Union[torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel, torch.nn.modules.Module]`): + Model object used during training + + Returns: + :obj:`torch.nn.modules.Module`: unwrapped module + """ + if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel): + model = model.module + else: + model = model + return model diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 395c98cb1e..4613b284bd 100755 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -336,3 +336,16 @@ class TrainerIntegrationTest(unittest.TestCase): trainer = get_regression_trainer(train_len=64, per_device_train_batch_size=16, gradient_accumulation_steps=5) train_output = trainer.train() self.assertEqual(train_output.global_step, int(self.n_epochs)) + + def test_flos_extraction(self): + trainer = get_regression_trainer(learning_rate=0.1) + + def assert_flos_extraction(trainer, wrapped_model_to_check): + self.assertEqual(trainer.model, trainer._actual_model(wrapped_model_to_check)) + self.assertGreaterEqual(getattr(trainer._actual_model(wrapped_model_to_check).config, "total_flos", 0), 0) + + # with plain model + assert_flos_extraction(trainer, trainer.model) + + # with enforced DataParallel + assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))