Use mmap option to load_state_dict (#28331)

Use mmap option to load_state_dict (#28331)
This commit is contained in:
Weiming Zhao
2024-01-10 00:57:30 -08:00
committed by GitHub
parent 0f2f0c634f
commit 701298d2d3
2 changed files with 60 additions and 3 deletions

View File

@@ -30,6 +30,7 @@ from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial, wraps from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from zipfile import is_zipfile
import torch import torch
from packaging import version from packaging import version
@@ -516,8 +517,16 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
map_location = "meta" map_location = "meta"
else: else:
map_location = "cpu" map_location = "cpu"
extra_args = {}
return torch.load(checkpoint_file, map_location=map_location, weights_only=True) # mmap can only be used with files serialized with zipfile-based format.
if (
isinstance(checkpoint_file, str)
and map_location != "meta"
and version.parse(torch.__version__) >= version.parse("2.1.0")
and is_zipfile(checkpoint_file)
):
extra_args = {"mmap": True}
return torch.load(checkpoint_file, map_location=map_location, weights_only=True, **extra_args)
except Exception as e: except Exception as e:
try: try:
with open(checkpoint_file) as f: with open(checkpoint_file) as f:

View File

@@ -101,7 +101,7 @@ if is_torch_available():
from torch import nn from torch import nn
from transformers import MODEL_MAPPING, AdaptiveEmbedding from transformers import MODEL_MAPPING, AdaptiveEmbedding
from transformers.modeling_utils import no_init_weights from transformers.modeling_utils import load_state_dict, no_init_weights
from transformers.pytorch_utils import id_tensor_storage from transformers.pytorch_utils import id_tensor_storage
@@ -536,6 +536,54 @@ class ModelTesterMixin:
).item() ).item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_torch_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING:
return
base_class = MODEL_MAPPING[config.__class__]
if isinstance(base_class, tuple):
base_class = base_class[0]
for model_class in self.all_model_classes:
if model_class == base_class:
continue
# make a copy of model class to not break future tests
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
class CopyClass(base_class):
pass
base_class_copy = CopyClass
# make sure that all keys are expected for test
base_class_copy._keys_to_ignore_on_load_missing = []
# make init deterministic, but make sure that
# non-initialized weights throw errors nevertheless
base_class_copy._init_weights = _mock_init_weights
base_class_copy.init_weights = _mock_all_init_weights
model = model_class(config)
state_dict = model.state_dict()
def check_equal(loaded):
for key in state_dict.keys():
max_diff = torch.max(
state_dict()[key] ^ loaded[key]
if isinstance(state_dict[key], torch.BoolTensor)
else torch.abs(state_dict[key] - loaded[key])
).item()
self.assertLessEqual(max_diff, 1e-6, msg=f"{key} not identical")
# check that certain keys didn't get saved with the model
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pytorch_model.bin")
torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=True)
check_equal(load_state_dict(pt_checkpoint_path))
torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=False)
check_equal(load_state_dict(pt_checkpoint_path))
def test_initialization(self): def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()