Add a post init method to all models (#14431)

* Add a post init method to all models

* Fix tests

* Fix last tests

* Fix templates

* Add comment

* Forgot to save
This commit is contained in:
Sylvain Gugger
2021-11-18 08:38:09 -05:00
committed by GitHub
parent 08816de16a
commit d83b0e0c07
70 changed files with 693 additions and 359 deletions

View File

@@ -222,14 +222,6 @@ class ModelTesterMixin:
config.gradient_checkpointing = True
model = model_class(config)
# Model does not have gradient checkpointing activated yet, it will be done at the first forward.
self.assertFalse(model.is_gradient_checkpointing)
model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
_ = model(**inputs)
# Model has gradient checkpointing activated after the first forward.
self.assertTrue(model.is_gradient_checkpointing)
def test_gradient_checkpointing_enable_disable(self):