Fix gradient_checkpointing backward compatibility (#14408)

* Fix gradient_checkpointing backward compatibility

* Remove needless line

* make sure mask prob is big enough and length small enough

* Fix tests

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Sylvain Gugger
2021-11-16 08:58:42 -05:00
committed by GitHub
parent 1cc453d33c
commit 040fd47162
8 changed files with 47 additions and 6 deletions

View File

@@ -213,6 +213,25 @@ class ModelTesterMixin:
)
self.assertTrue(len(load_result.unexpected_keys) == 0)
def test_gradient_checkpointing_backward_compatibility(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if not model_class.supports_gradient_checkpointing:
continue
config.gradient_checkpointing = True
model = model_class(config)
# Model does not have gradient checkpointing activated yet, it will be done at the first forward.
self.assertFalse(model.is_gradient_checkpointing)
model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
_ = model(**inputs)
# Model has gradient checkpointing activated after the first forward.
self.assertTrue(model.is_gradient_checkpointing)
def test_gradient_checkpointing_enable_disable(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -418,6 +437,7 @@ class ModelTesterMixin:
continue
model = model_class(config)
model.to(torch_device)
model.gradient_checkpointing_enable()
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss