diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 79a2c294c6..a2913f2296 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3835,27 +3835,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # We're going to remove aliases before saving ptrs = collections.defaultdict(list) for name, tensor in state_dict.items(): - # Sometimes in the state_dict we have non-tensor objects. - # e.g. in bitsandbytes we have some `str` objects in the state_dict - if isinstance(tensor, torch.Tensor): - ptrs[id_tensor_storage(tensor)].append(name) - else: + if not isinstance(tensor, torch.Tensor): + # Sometimes in the state_dict we have non-tensor objects. + # e.g. in bitsandbytes we have some `str` objects in the state_dict # In the non-tensor case, fall back to the pointer of the object itself ptrs[id(tensor)].append(name) - # These are all the pointers of shared tensors - if hasattr(self, "hf_device_map"): - # if the model has offloaded parameters, we must check using find_tied_parameters() - tied_params = find_tied_parameters(self) - if tied_params: - tied_names = tied_params[0] - shared_ptrs = { - ptr: names for ptr, names in ptrs.items() if any(name in tied_names for name in names) - } + elif tensor.device.type == "meta": + # In offloaded cases, there may be meta tensors in the state_dict. + # For these cases, key by the pointer of the original tensor object + # (state_dict tensors are detached and therefore no longer shared) + tensor = self.get_parameter(name) + ptrs[id(tensor)].append(name) + else: - shared_ptrs = {} - else: - shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} + ptrs[id_tensor_storage(tensor)].append(name) + + shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} # Recursively descend to find tied weight keys _tied_weights_keys = _get_tied_weight_keys(self) @@ -3899,7 +3895,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi if len(error_names) > 0: raise RuntimeError( - f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.", + f"The weights trying to be saved contained shared tensors {error_names} that are mismatching " + "the transformers base configuration. Try saving using `safe_serialization=False`, setting the " + "`_dynamic_tied_weights_keys` attribute for affected modules, or remove this tensor sharing.", ) # Shard the model if it is too big. diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 6da45b1639..6754e22912 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1187,6 +1187,29 @@ class ModelUtilsTest(TestCasePlus): torch.testing.assert_close(output, presaved_output, rtol=1e-4, atol=1e-4) torch.testing.assert_close(presaved_output, postsaved_output) + @require_accelerate + @mark.accelerate_tests + @require_torch_accelerator + def test_save_offloaded_model_dynamic_tied_weights_keys(self): + from accelerate import dispatch_model + + device_map = {"base": f"{torch_device}:0", "linear": "cpu", "linear2": "cpu"} + model = ModelWithHead(PretrainedConfig()) + dispatch_model(model, device_map) + + transform_a = torch.nn.Linear(1, 1, bias=False) + transform_a._dynamic_tied_weights_keys = ["weight"] + transform_b = torch.nn.Linear(1, 1, bias=False) + transform_b._dynamic_tied_weights_keys = ["weight"] + + model.linear.register_module("transform_a", transform_a) + model.linear.register_module("transform_b", transform_b) + model.linear2.register_module("transform_a", transform_a) + model.linear2.register_module("transform_b", transform_b) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + @require_safetensors def test_use_safetensors(self): # Should not raise anymore