[Gradient checkpointing] Correct disabling find_unused_parameters in Trainer when gradient checkpointing is enabled (#13961)

* up

* correct test
This commit is contained in:
Patrick von Platen
2021-10-11 15:34:01 +02:00
committed by GitHub
parent 4a18337bae
commit dca6796876
3 changed files with 32 additions and 3 deletions

View File

@@ -946,7 +946,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
self.base_model._prune_heads(heads_to_prune)
def gradient_checkpointing_enable(self, flag: bool = True):
def gradient_checkpointing_enable(self):
"""
Activates gradient checkpointing for the current model.
@@ -957,7 +957,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
self.apply(partial(self._set_gradient_checkpointing, value=True))
def gradient_checkpointing_disable(self, flag: bool = True):
def gradient_checkpointing_disable(self):
"""
Deactivates gradient checkpointing for the current model.
@@ -967,6 +967,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if self.supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False))
@property
def is_gradient_checkpointing(self) -> bool:
    """
    Whether gradient checkpointing is activated for this model or not.

    Returns :obj:`True` as soon as one module yielded by ``self.modules()`` carries a truthy
    ``gradient_checkpointing`` attribute. Modules that do not define the attribute count as not checkpointing.

    Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
    activations".
    """
    # `getattr` with a default does the presence check and the read in a single lookup.
    return any(getattr(module, "gradient_checkpointing", False) for module in self.modules())
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],

View File

@@ -996,7 +996,7 @@ class Trainer:
elif isinstance(model, PreTrainedModel):
# find_unused_parameters breaks checkpointing as per
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
find_unused_parameters = not getattr(model.config, "_gradient_checkpointing", False)
find_unused_parameters = not model.is_gradient_checkpointing
else:
find_unused_parameters = True
model = nn.parallel.DistributedDataParallel(

View File

@@ -197,6 +197,25 @@ class ModelTesterMixin:
)
self.assertTrue(len(load_result.unexpected_keys) == 0)
def test_gradient_checkpointing_enable_disable(self):
    """Check that toggling gradient checkpointing is reflected by ``is_gradient_checkpointing``."""
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    # Lazily keep only the model classes that declare support for gradient checkpointing.
    supported_classes = (cls for cls in self.all_model_classes if cls.supports_gradient_checkpointing)

    for model_class in supported_classes:
        # A freshly built model must start with gradient checkpointing turned off.
        instance = model_class(config)
        self.assertFalse(instance.is_gradient_checkpointing)

        # Enabling must be visible through the property ...
        instance.gradient_checkpointing_enable()
        self.assertTrue(instance.is_gradient_checkpointing)

        # ... and disabling must bring the model back to its initial state.
        instance.gradient_checkpointing_disable()
        self.assertFalse(instance.is_gradient_checkpointing)
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)