[Core] [Offloading] Enable saving offloaded models with multiple shared tensor groups (#39263)
* fix counting meta tensors, fix onloading meta tensors Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * remove unrelated fix Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add test Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
@@ -3835,27 +3835,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
# We're going to remove aliases before saving
|
# We're going to remove aliases before saving
|
||||||
ptrs = collections.defaultdict(list)
|
ptrs = collections.defaultdict(list)
|
||||||
for name, tensor in state_dict.items():
|
for name, tensor in state_dict.items():
|
||||||
# Sometimes in the state_dict we have non-tensor objects.
|
if not isinstance(tensor, torch.Tensor):
|
||||||
# e.g. in bitsandbytes we have some `str` objects in the state_dict
|
# Sometimes in the state_dict we have non-tensor objects.
|
||||||
if isinstance(tensor, torch.Tensor):
|
# e.g. in bitsandbytes we have some `str` objects in the state_dict
|
||||||
ptrs[id_tensor_storage(tensor)].append(name)
|
|
||||||
else:
|
|
||||||
# In the non-tensor case, fall back to the pointer of the object itself
|
# In the non-tensor case, fall back to the pointer of the object itself
|
||||||
ptrs[id(tensor)].append(name)
|
ptrs[id(tensor)].append(name)
|
||||||
|
|
||||||
# These are all the pointers of shared tensors
|
elif tensor.device.type == "meta":
|
||||||
if hasattr(self, "hf_device_map"):
|
# In offloaded cases, there may be meta tensors in the state_dict.
|
||||||
# if the model has offloaded parameters, we must check using find_tied_parameters()
|
# For these cases, key by the pointer of the original tensor object
|
||||||
tied_params = find_tied_parameters(self)
|
# (state_dict tensors are detached and therefore no longer shared)
|
||||||
if tied_params:
|
tensor = self.get_parameter(name)
|
||||||
tied_names = tied_params[0]
|
ptrs[id(tensor)].append(name)
|
||||||
shared_ptrs = {
|
|
||||||
ptr: names for ptr, names in ptrs.items() if any(name in tied_names for name in names)
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
shared_ptrs = {}
|
ptrs[id_tensor_storage(tensor)].append(name)
|
||||||
else:
|
|
||||||
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
|
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
|
||||||
|
|
||||||
# Recursively descend to find tied weight keys
|
# Recursively descend to find tied weight keys
|
||||||
_tied_weights_keys = _get_tied_weight_keys(self)
|
_tied_weights_keys = _get_tied_weight_keys(self)
|
||||||
@@ -3899,7 +3895,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
|||||||
|
|
||||||
if len(error_names) > 0:
|
if len(error_names) > 0:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
|
f"The weights trying to be saved contained shared tensors {error_names} that are mismatching "
|
||||||
|
"the transformers base configuration. Try saving using `safe_serialization=False`, setting the "
|
||||||
|
"`_dynamic_tied_weights_keys` attribute for affected modules, or remove this tensor sharing.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Shard the model if it is too big.
|
# Shard the model if it is too big.
|
||||||
|
|||||||
@@ -1187,6 +1187,29 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
torch.testing.assert_close(output, presaved_output, rtol=1e-4, atol=1e-4)
|
torch.testing.assert_close(output, presaved_output, rtol=1e-4, atol=1e-4)
|
||||||
torch.testing.assert_close(presaved_output, postsaved_output)
|
torch.testing.assert_close(presaved_output, postsaved_output)
|
||||||
|
|
||||||
|
@require_accelerate
|
||||||
|
@mark.accelerate_tests
|
||||||
|
@require_torch_accelerator
|
||||||
|
def test_save_offloaded_model_dynamic_tied_weights_keys(self):
|
||||||
|
from accelerate import dispatch_model
|
||||||
|
|
||||||
|
device_map = {"base": f"{torch_device}:0", "linear": "cpu", "linear2": "cpu"}
|
||||||
|
model = ModelWithHead(PretrainedConfig())
|
||||||
|
dispatch_model(model, device_map)
|
||||||
|
|
||||||
|
transform_a = torch.nn.Linear(1, 1, bias=False)
|
||||||
|
transform_a._dynamic_tied_weights_keys = ["weight"]
|
||||||
|
transform_b = torch.nn.Linear(1, 1, bias=False)
|
||||||
|
transform_b._dynamic_tied_weights_keys = ["weight"]
|
||||||
|
|
||||||
|
model.linear.register_module("transform_a", transform_a)
|
||||||
|
model.linear.register_module("transform_b", transform_b)
|
||||||
|
model.linear2.register_module("transform_a", transform_a)
|
||||||
|
model.linear2.register_module("transform_b", transform_b)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
@require_safetensors
|
@require_safetensors
|
||||||
def test_use_safetensors(self):
|
def test_use_safetensors(self):
|
||||||
# Should not raise anymore
|
# Should not raise anymore
|
||||||
|
|||||||
Reference in New Issue
Block a user