Fix gradient_checkpointing backward compatibility (#14408)

* Fix gradient_checkpointing backward compatibility

* Remove needless line

* make sure mask prob is big enough and length small enough

* Fix tests

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Sylvain Gugger
2021-11-16 08:58:42 -05:00
committed by GitHub
parent 1cc453d33c
commit 040fd47162
8 changed files with 47 additions and 6 deletions

View File

@@ -367,6 +367,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
if model_class.__name__ == "DeiTForImageClassificationWithTeacher":
continue
model = model_class(config)
model.gradient_checkpointing_enable()
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)