From dca67968769534f97167abdc08401cc0f53a9005 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 11 Oct 2021 15:34:01 +0200 Subject: [PATCH] [Gradient checkpoining] Correct disabling `find_unused_parameters` in Trainer when gradient checkpointing is enabled (#13961) * up * correct test --- src/transformers/modeling_utils.py | 14 ++++++++++++-- src/transformers/trainer.py | 2 +- tests/test_modeling_common.py | 19 +++++++++++++++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 783a5537e2..c84aee206d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -946,7 +946,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix self.base_model._prune_heads(heads_to_prune) - def gradient_checkpointing_enable(self, flag: bool = True): + def gradient_checkpointing_enable(self): """ Activates gradient checkpointing for the current model. @@ -957,7 +957,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") self.apply(partial(self._set_gradient_checkpointing, value=True)) - def gradient_checkpointing_disable(self, flag: bool = True): + def gradient_checkpointing_disable(self): """ Deactivates gradient checkpointing for the current model. @@ -967,6 +967,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if self.supports_gradient_checkpointing: self.apply(partial(self._set_gradient_checkpointing, value=False)) + @property + def is_gradient_checkpointing(self) -> bool: + """ + Whether gradient checkpointing is activated for this model or not. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e59b5982f8..b3b2413340 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -996,7 +996,7 @@ class Trainer: elif isinstance(model, PreTrainedModel): # find_unused_parameters breaks checkpointing as per # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 - find_unused_parameters = not getattr(model.config, "_gradient_checkpointing", False) + find_unused_parameters = not model.is_gradient_checkpointing else: find_unused_parameters = True model = nn.parallel.DistributedDataParallel( diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 7c1ca8dd44..f946d59017 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -197,6 +197,25 @@ class ModelTesterMixin: ) self.assertTrue(len(load_result.unexpected_keys) == 0) + def test_gradient_checkpointing_enable_disable(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if not model_class.supports_gradient_checkpointing: + continue + + # at init model should have gradient checkpointing disabled + model = model_class(config) + self.assertFalse(model.is_gradient_checkpointing) + + # check enable works + model.gradient_checkpointing_enable() + self.assertTrue(model.is_gradient_checkpointing) + + # check disable works + model.gradient_checkpointing_disable() + self.assertFalse(model.is_gradient_checkpointing) + def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: module.weight.data.fill_(3)