Test checkpointing (#11682)
* Add test and see where CI is unhappy * Load with strict=False
This commit is contained in:
@@ -177,6 +177,13 @@ class ModelTesterMixin:
|
||||
for k in _keys_to_ignore_on_save:
|
||||
self.assertNotIn(k, state_dict_saved)
|
||||
|
||||
# Test we can load the state dict in the model, necessary for the checkpointing API in Trainer.
|
||||
load_result = model.load_state_dict(state_dict_saved, strict=False)
|
||||
self.assertTrue(
|
||||
len(load_result.missing_keys) == 0 or load_result.missing_keys == model._keys_to_ignore_on_save
|
||||
)
|
||||
self.assertTrue(len(load_result.unexpected_keys) == 0)
|
||||
|
||||
def _mock_init_weights(self, module):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(3)
|
||||
|
||||
Reference in New Issue
Block a user