Fix gradient_checkpointing backward compatibility (#14408)
* Fix gradient_checkpointing backward compatibility * Remove needless line * make sure mask prob is big enough and length small enough * Fix tests Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -367,6 +367,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if model_class.__name__ == "DeiTForImageClassificationWithTeacher":
|
||||
continue
|
||||
model = model_class(config)
|
||||
model.gradient_checkpointing_enable()
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
|
||||
Reference in New Issue
Block a user