From 31d81943c994d11a079223809cfca89bfaaee363 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 16 Jul 2025 04:44:40 -0400 Subject: [PATCH] [Core] [Offloading] Fix saving offloaded submodules (#39280) * fix counting meta tensors, fix onloading meta tensors Signed-off-by: Kyle Sayers * remove unrelated fix Signed-off-by: Kyle Sayers * remove unrelated change Signed-off-by: Kyle Sayers * add clarifying comment Signed-off-by: Kyle Sayers * add test_save_offloaded_model_with_direct_params Signed-off-by: Kyle Sayers * fix merge conflict, add decorators Signed-off-by: Kyle Sayers --------- Signed-off-by: Kyle Sayers --- src/transformers/modeling_utils.py | 13 +++++---- tests/utils/test_modeling_utils.py | 45 ++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3202ef47c1..6070bd56bf 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3900,12 +3900,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi # init state_dict for this shard shard_state_dict = dict.fromkeys(shard, "") for module_name in shard: - # skip to collect this weight again - if shard_state_dict.get(module_name) != "": - continue - module = module_map[module_name] - # update state dict with onloaded parameters - shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict) + # note that get_state_dict_from_offload can update with meta tensors + # if both a parent module and its descendant are offloaded + tensor = shard_state_dict[module_name] + if tensor == "" or (isinstance(tensor, torch.Tensor) and tensor.device.type == "meta"): + # update state dict with onloaded parameters + module = module_map[module_name] + shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict) # assign shard to be the completed state dict shard = shard_state_dict diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 6754e22912..4da0aefce9 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -158,6 +158,38 @@ if is_torch_available(): def forward(self, x): return self.linear2(self.linear(self.base(x))) + class ModelWithDirectParam(PreTrainedModel): + base_model_prefix = "base" + config_class = PretrainedConfig + + def _init_weights(self, module): + pass + + def __init__(self, config): + super().__init__(config) + # direct params and submodules is helpful for testing offloading logic + self.weight = nn.Parameter(torch.rand((5, 5))) + self.base = BaseModel(config) + + def forward(self, x): + return self.base(x @ self.weight.T) + + class ModelWithDirectParamSubmodule(PreTrainedModel): + base_model_prefix = "base" + config_class = PretrainedConfig + + def _init_weights(self, module): + pass + + def __init__(self, config): + super().__init__(config) + self.submodule = ModelWithDirectParam(config) + # needed so model can have at least one module on accelerator + self.linear = nn.Linear(5, 5) + + def forward(self, x): + return self.linear(self.submodule(x)) + class ModelWithHeadAndTiedWeights(PreTrainedModel): base_model_prefix = "base" config_class = PretrainedConfig @@ -1187,6 +1219,19 @@ class ModelUtilsTest(TestCasePlus): torch.testing.assert_close(output, presaved_output, rtol=1e-4, atol=1e-4) torch.testing.assert_close(presaved_output, postsaved_output) + @require_accelerate + @mark.accelerate_tests + @require_torch_accelerator + def test_save_offloaded_model_with_direct_params(self): + from accelerate import dispatch_model + + device_map = {"submodule": "cpu", "linear": f"{torch_device}:0"} + model = ModelWithDirectParamSubmodule(PretrainedConfig()) + dispatch_model(model, device_map) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + @require_accelerate @mark.accelerate_tests @require_torch_accelerator