[Core] [Offloading] Enable saving offloaded models with multiple shared tensor groups (#39263)

* fix counting meta tensors, fix onloading meta tensors

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* remove unrelated fix

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* add test

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
Kyle Sayers
2025-07-10 12:33:30 -04:00
committed by GitHub
parent df49b399dc
commit bdc8028cb3
2 changed files with 39 additions and 18 deletions

View File

@@ -3835,27 +3835,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# We're going to remove aliases before saving
ptrs = collections.defaultdict(list)
for name, tensor in state_dict.items():
# Sometimes in the state_dict we have non-tensor objects.
# e.g. in bitsandbytes we have some `str` objects in the state_dict
if isinstance(tensor, torch.Tensor):
ptrs[id_tensor_storage(tensor)].append(name)
else:
if not isinstance(tensor, torch.Tensor):
# Sometimes in the state_dict we have non-tensor objects.
# e.g. in bitsandbytes we have some `str` objects in the state_dict
# In the non-tensor case, fall back to the pointer of the object itself
ptrs[id(tensor)].append(name)
# These are all the pointers of shared tensors
if hasattr(self, "hf_device_map"):
# if the model has offloaded parameters, we must check using find_tied_parameters()
tied_params = find_tied_parameters(self)
if tied_params:
tied_names = tied_params[0]
shared_ptrs = {
ptr: names for ptr, names in ptrs.items() if any(name in tied_names for name in names)
}
elif tensor.device.type == "meta":
# In offloaded cases, there may be meta tensors in the state_dict.
# For these cases, key by the pointer of the original tensor object
# (state_dict tensors are detached and therefore no longer shared)
tensor = self.get_parameter(name)
ptrs[id(tensor)].append(name)
else:
shared_ptrs = {}
else:
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
ptrs[id_tensor_storage(tensor)].append(name)
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
# Recursively descend to find tied weight keys
_tied_weights_keys = _get_tied_weight_keys(self)
@@ -3899,7 +3895,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if len(error_names) > 0:
raise RuntimeError(
f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
f"The weights trying to be saved contained shared tensors {error_names} that are mismatching "
"the transformers base configuration. Try saving using `safe_serialization=False`, setting the "
"`_dynamic_tied_weights_keys` attribute for affected modules, or remove this tensor sharing.",
)
# Shard the model if it is too big.