[multiple models] skip saving/loading deterministic state_dict keys (#7878)

* make the save_load special key tests common

* handle mbart

* cleaner solution

* fix

* move test_save_load_missing_keys back into fstm for now

* restore

* style

* add marian

* add pegasus

* blenderbot

* revert - no static embed
This commit is contained in:
Stas Bekman
2020-10-21 05:06:07 -07:00
committed by GitHub
parent 006a16483f
commit 57516c0cc8
8 changed files with 144 additions and 26 deletions

View File

@@ -22,6 +22,7 @@ import unittest
from typing import List, Tuple
from transformers import is_torch_available
from transformers.file_utils import WEIGHTS_NAME
from transformers.testing_utils import require_torch, require_torch_multigpu, slow, torch_device
@@ -129,6 +130,27 @@ class ModelTesterMixin:
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def test_save_load_keys_to_never_save(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
keys_to_never_save = getattr(model, "keys_to_never_save", None)
if keys_to_never_save is None:
continue
# check the keys are in the original state_dict
for k in keys_to_never_save:
self.assertIn(k, model.state_dict())
# check that certain keys didn't get saved with the model
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
state_dict_saved = torch.load(output_model_file)
for k in keys_to_never_save:
self.assertNotIn(k, state_dict_saved)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()