Use mmap option to load_state_dict (#28331)

Use mmap option to load_state_dict (#28331)
This commit is contained in:
Weiming Zhao
2024-01-10 00:57:30 -08:00
committed by GitHub
parent 0f2f0c634f
commit 701298d2d3
2 changed files with 60 additions and 3 deletions

View File

@@ -101,7 +101,7 @@ if is_torch_available():
from torch import nn
from transformers import MODEL_MAPPING, AdaptiveEmbedding
from transformers.modeling_utils import no_init_weights
from transformers.modeling_utils import load_state_dict, no_init_weights
from transformers.pytorch_utils import id_tensor_storage
@@ -536,6 +536,54 @@ class ModelTesterMixin:
).item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_torch_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING:
return
base_class = MODEL_MAPPING[config.__class__]
if isinstance(base_class, tuple):
base_class = base_class[0]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
# make a copy of model class to not break future tests
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
class CopyClass(base_class):
pass
base_class_copy = CopyClass
# make sure that all keys are expected for test
base_class_copy._keys_to_ignore_on_load_missing = []
# make init deterministic, but make sure that
# non-initialized weights throw errors nevertheless
base_class_copy._init_weights = _mock_init_weights
base_class_copy.init_weights = _mock_all_init_weights
model = model_class(config)
state_dict = model.state_dict()
def check_equal(loaded):
for key in state_dict.keys():
max_diff = torch.max(
state_dict()[key] ^ loaded[key]
if isinstance(state_dict[key], torch.BoolTensor)
else torch.abs(state_dict[key] - loaded[key])
).item()
self.assertLessEqual(max_diff, 1e-6, msg=f"{key} not identical")
# check that certain keys didn't get saved with the model
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pytorch_model.bin")
torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=True)
check_equal(load_state_dict(pt_checkpoint_path))
torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=False)
check_equal(load_state_dict(pt_checkpoint_path))
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()