Use mmap option to load_state_dict (#28331)
Use mmap option to load_state_dict (#28331)
This commit is contained in:
@@ -30,6 +30,7 @@ from contextlib import contextmanager
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
from zipfile import is_zipfile
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -516,8 +517,16 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
|
|||||||
map_location = "meta"
|
map_location = "meta"
|
||||||
else:
|
else:
|
||||||
map_location = "cpu"
|
map_location = "cpu"
|
||||||
|
extra_args = {}
|
||||||
return torch.load(checkpoint_file, map_location=map_location, weights_only=True)
|
# mmap can only be used with files serialized with zipfile-based format.
|
||||||
|
if (
|
||||||
|
isinstance(checkpoint_file, str)
|
||||||
|
and map_location != "meta"
|
||||||
|
and version.parse(torch.__version__) >= version.parse("2.1.0")
|
||||||
|
and is_zipfile(checkpoint_file)
|
||||||
|
):
|
||||||
|
extra_args = {"mmap": True}
|
||||||
|
return torch.load(checkpoint_file, map_location=map_location, weights_only=True, **extra_args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
try:
|
try:
|
||||||
with open(checkpoint_file) as f:
|
with open(checkpoint_file) as f:
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ if is_torch_available():
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from transformers import MODEL_MAPPING, AdaptiveEmbedding
|
from transformers import MODEL_MAPPING, AdaptiveEmbedding
|
||||||
from transformers.modeling_utils import no_init_weights
|
from transformers.modeling_utils import load_state_dict, no_init_weights
|
||||||
from transformers.pytorch_utils import id_tensor_storage
|
from transformers.pytorch_utils import id_tensor_storage
|
||||||
|
|
||||||
|
|
||||||
@@ -536,6 +536,54 @@ class ModelTesterMixin:
|
|||||||
).item()
|
).item()
|
||||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
def test_torch_save_load(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
if config.__class__ not in MODEL_MAPPING:
|
||||||
|
return
|
||||||
|
base_class = MODEL_MAPPING[config.__class__]
|
||||||
|
|
||||||
|
if isinstance(base_class, tuple):
|
||||||
|
base_class = base_class[0]
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
if model_class == base_class:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# make a copy of model class to not break future tests
|
||||||
|
# from https://stackoverflow.com/questions/9541025/how-to-copy-a-python-class
|
||||||
|
class CopyClass(base_class):
|
||||||
|
pass
|
||||||
|
|
||||||
|
base_class_copy = CopyClass
|
||||||
|
|
||||||
|
# make sure that all keys are expected for test
|
||||||
|
base_class_copy._keys_to_ignore_on_load_missing = []
|
||||||
|
|
||||||
|
# make init deterministic, but make sure that
|
||||||
|
# non-initialized weights throw errors nevertheless
|
||||||
|
base_class_copy._init_weights = _mock_init_weights
|
||||||
|
base_class_copy.init_weights = _mock_all_init_weights
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
|
||||||
|
def check_equal(loaded):
|
||||||
|
for key in state_dict.keys():
|
||||||
|
max_diff = torch.max(
|
||||||
|
state_dict()[key] ^ loaded[key]
|
||||||
|
if isinstance(state_dict[key], torch.BoolTensor)
|
||||||
|
else torch.abs(state_dict[key] - loaded[key])
|
||||||
|
).item()
|
||||||
|
self.assertLessEqual(max_diff, 1e-6, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
# check that certain keys didn't get saved with the model
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
pt_checkpoint_path = os.path.join(tmpdirname, "pytorch_model.bin")
|
||||||
|
torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=True)
|
||||||
|
check_equal(load_state_dict(pt_checkpoint_path))
|
||||||
|
torch.save(state_dict, pt_checkpoint_path, _use_new_zipfile_serialization=False)
|
||||||
|
check_equal(load_state_dict(pt_checkpoint_path))
|
||||||
|
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user