Use mmap option to load_state_dict (#28331)
Use mmap option to load_state_dict (#28331)
This commit is contained in:
@@ -101,7 +101,7 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
|
||||
from transformers import MODEL_MAPPING, AdaptiveEmbedding
|
||||
from transformers.modeling_utils import no_init_weights
|
||||
from transformers.modeling_utils import load_state_dict, no_init_weights
|
||||
from transformers.pytorch_utils import id_tensor_storage
|
||||
|
||||
|
||||
@@ -536,6 +536,54 @@ class ModelTesterMixin:
|
||||
).item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_torch_save_load(self):
    """Check that a state dict survives a ``torch.save`` / ``load_state_dict`` round trip.

    For every class in ``self.all_model_classes`` other than the base class
    registered in ``MODEL_MAPPING`` for the tested config, the model's state
    dict is written with ``torch.save`` twice -- once with the zipfile
    serialization and once with the legacy one -- and reloaded each time
    through ``load_state_dict``. Every reloaded tensor must match the
    original within ``1e-6``.

    Returns early, without asserting anything, when the config class is not
    registered in ``MODEL_MAPPING``.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    if config.__class__ not in MODEL_MAPPING:
        return
    base_class = MODEL_MAPPING[config.__class__]

    # MODEL_MAPPING may map a config to several classes; use the first one.
    if isinstance(base_class, tuple):
        base_class = base_class[0]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            continue

        # make a copy of model class to not break future tests
        # from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
        class CopyClass(base_class):
            pass

        base_class_copy = CopyClass

        # make sure that all keys are expected for test
        base_class_copy._keys_to_ignore_on_load_missing = []

        # make init deterministic, but make sure that
        # non-initialized weights throw errors nevertheless
        # NOTE(review): these overrides are set on the copy of the base class,
        # while the model below is built from `model_class`; they look like
        # they have no effect on it -- confirm before relying on them.
        base_class_copy._init_weights = _mock_init_weights
        base_class_copy.init_weights = _mock_all_init_weights

        model = model_class(config)
        state_dict = model.state_dict()

        def check_equal(loaded):
            """Assert that every entry of ``loaded`` matches ``state_dict``."""
            for key in state_dict:
                expected = state_dict[key]
                if expected.dtype == torch.bool:
                    # Subtraction is not defined for boolean tensors, so
                    # compare them with XOR. `state_dict` is a plain mapping
                    # here: index it, do not call it.
                    difference = expected ^ loaded[key]
                else:
                    difference = torch.abs(expected - loaded[key])
                max_diff = torch.max(difference).item()
                self.assertLessEqual(max_diff, 1e-6, msg=f"{key} not identical")

        # round-trip the state dict through both torch.save serialization formats
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_checkpoint_path = os.path.join(tmpdirname, "pytorch_model.bin")
            torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=True)
            check_equal(load_state_dict(pt_checkpoint_path))
            torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=False)
            check_equal(load_state_dict(pt_checkpoint_path))
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user