[Gradient checkpointing] Correct disabling find_unused_parameters in Trainer when gradient checkpointing is enabled (#13961)

* up

* correct test
This commit is contained in:
Patrick von Platen
2021-10-11 15:34:01 +02:00
committed by GitHub
parent 4a18337bae
commit dca6796876
3 changed files with 32 additions and 3 deletions

View File

@@ -197,6 +197,25 @@ class ModelTesterMixin:
)
self.assertTrue(len(load_result.unexpected_keys) == 0)
def test_gradient_checkpointing_enable_disable(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if not model_class.supports_gradient_checkpointing:
continue
# at init model should have gradient checkpointing disabled
model = model_class(config)
self.assertFalse(model.is_gradient_checkpointing)
# check enable works
model.gradient_checkpointing_enable()
self.assertTrue(model.is_gradient_checkpointing)
# check disable works
model.gradient_checkpointing_disable()
self.assertFalse(model.is_gradient_checkpointing)
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)