Fix gradient_checkpointing backward compatibility (#14408)
* Fix gradient_checkpointing backward compatibility
* Remove needless line
* Make sure mask prob is big enough and length small enough
* Fix tests

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -213,6 +213,25 @@ class ModelTesterMixin:
|
||||
)
|
||||
self.assertTrue(len(load_result.unexpected_keys) == 0)
|
||||
|
||||
def test_gradient_checkpointing_backward_compatibility(self):
    """Check that enabling gradient checkpointing through the config still works.

    For each model class that reports ``supports_gradient_checkpointing``, a
    model is built from a config whose ``gradient_checkpointing`` attribute is
    set to ``True``. The test asserts that ``is_gradient_checkpointing`` is
    ``False`` right after construction and ``True`` after one forward pass.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Lazily filter out classes that do not support gradient checkpointing.
    checkpointable_classes = (
        candidate for candidate in self.all_model_classes if candidate.supports_gradient_checkpointing
    )

    for model_class in checkpointable_classes:
        config.gradient_checkpointing = True
        net = model_class(config)

        # Model does not have gradient checkpointing activated yet, it will be done at the first forward.
        self.assertFalse(net.is_gradient_checkpointing)

        net.to(torch_device)
        batch = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
        net(**batch)

        # Model has gradient checkpointing activated after the first forward.
        self.assertTrue(net.is_gradient_checkpointing)
|
||||
|
||||
def test_gradient_checkpointing_enable_disable(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -418,6 +437,7 @@ class ModelTesterMixin:
|
||||
continue
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.gradient_checkpointing_enable()
|
||||
model.train()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
loss = model(**inputs).loss
|
||||
|
||||
Reference in New Issue
Block a user