Test checkpointing (#11682)

* Add test and see where CI is unhappy

* Load with strict=False
This commit is contained in:
Sylvain Gugger
2021-05-11 12:02:48 -04:00
committed by GitHub
parent d9b286272c
commit f13f1f8fb8
2 changed files with 19 additions and 1 deletions

View File

@@ -177,6 +177,13 @@ class ModelTesterMixin:
for k in _keys_to_ignore_on_save:
self.assertNotIn(k, state_dict_saved)
# Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
load_result = model.load_state_dict(state_dict_saved, strict=False)
self.assertTrue(
len(load_result.missing_keys) == 0 or load_result.missing_keys == model._keys_to_ignore_on_save
)
self.assertTrue(len(load_result.unexpected_keys) == 0)
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(3)