[multiple models] skip saving/loading deterministic state_dict keys (#7878)
* make the save_load special key tests common * handle mbart * cleaner solution * fix * move test_save_load_missing_keys back into fstm for now * restore * style * add marian * add pegasus * blenderbot * revert - no static embed
This commit is contained in:
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@@ -21,7 +20,7 @@ import timeout_decorator # noqa
|
||||
|
||||
from parameterized import parameterized
|
||||
from transformers import is_torch_available
|
||||
from transformers.file_utils import WEIGHTS_NAME, cached_property
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
@@ -203,8 +202,9 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
)[0]
|
||||
_assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)
|
||||
|
||||
def test_save_load_strict(self):
|
||||
def test_save_load_missing_keys(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
@@ -213,27 +213,6 @@ class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
|
||||
self.assertEqual(info["missing_keys"], [])
|
||||
|
||||
def test_save_load_no_save_keys(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
state_dict_no_save_keys = getattr(model, "state_dict_no_save_keys", None)
|
||||
if state_dict_no_save_keys is None:
|
||||
continue
|
||||
|
||||
# check the keys are in the original state_dict
|
||||
for k in state_dict_no_save_keys:
|
||||
self.assertIn(k, model.state_dict())
|
||||
|
||||
# check that certain keys didn't get saved with the model
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
|
||||
state_dict_saved = torch.load(output_model_file)
|
||||
for k in state_dict_no_save_keys:
|
||||
self.assertNotIn(k, state_dict_saved)
|
||||
|
||||
@unittest.skip("can't be implemented for FSMT due to dual vocab.")
|
||||
def test_resize_tokens_embeddings(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user