consistent ignore keys + make private (#8737)
* consistent ignore keys + make private * style * - authorized_missing_keys => _keys_to_ignore_on_load_missing - authorized_unexpected_keys => _keys_to_ignore_on_load_unexpected * move public doc of private attributes to private comment
This commit is contained in:
@@ -135,17 +135,17 @@ class ModelTesterMixin:
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def test_save_load_keys_to_never_save(self):
|
||||
def test_save_load__keys_to_ignore_on_save(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
keys_to_never_save = getattr(model, "keys_to_never_save", None)
|
||||
if keys_to_never_save is None:
|
||||
_keys_to_ignore_on_save = getattr(model, "_keys_to_ignore_on_save", None)
|
||||
if _keys_to_ignore_on_save is None:
|
||||
continue
|
||||
|
||||
# check the keys are in the original state_dict
|
||||
for k in keys_to_never_save:
|
||||
for k in _keys_to_ignore_on_save:
|
||||
self.assertIn(k, model.state_dict())
|
||||
|
||||
# check that certain keys didn't get saved with the model
|
||||
@@ -153,7 +153,7 @@ class ModelTesterMixin:
|
||||
model.save_pretrained(tmpdirname)
|
||||
output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
|
||||
state_dict_saved = torch.load(output_model_file)
|
||||
for k in keys_to_never_save:
|
||||
for k in _keys_to_ignore_on_save:
|
||||
self.assertNotIn(k, state_dict_saved)
|
||||
|
||||
def test_initialization(self):
|
||||
|
||||
Reference in New Issue
Block a user