[Core] [Offloading] Fix saving offloaded submodules (#39280)

* fix counting meta tensors, fix onloading meta tensors

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* remove unrelated fix

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* remove unrelated change

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* add clarifying comment

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* add test_save_offloaded_model_with_direct_params

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* fix merge conflict, add decorators

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
Kyle Sayers
2025-07-16 04:44:40 -04:00
committed by GitHub
parent add43c4d09
commit 31d81943c9
2 changed files with 52 additions and 6 deletions

View File

@@ -158,6 +158,38 @@ if is_torch_available():
def forward(self, x):
return self.linear2(self.linear(self.base(x)))
class ModelWithDirectParam(PreTrainedModel):
base_model_prefix = "base"
config_class = PretrainedConfig
def _init_weights(self, module):
pass
def __init__(self, config):
super().__init__(config)
# direct params and submodules is helpful for testing offloading logic
self.weight = nn.Parameter(torch.rand((5, 5)))
self.base = BaseModel(config)
def forward(self, x):
return self.base(x @ self.weight.T)
class ModelWithDirectParamSubmodule(PreTrainedModel):
base_model_prefix = "base"
config_class = PretrainedConfig
def _init_weights(self, module):
pass
def __init__(self, config):
super().__init__(config)
self.submodule = ModelWithDirectParam(config)
# needed so model can have at least one module on accelerator
self.linear = nn.Linear(5, 5)
def forward(self, x):
return self.linear(self.submodule(x))
class ModelWithHeadAndTiedWeights(PreTrainedModel):
base_model_prefix = "base"
config_class = PretrainedConfig
@@ -1187,6 +1219,19 @@ class ModelUtilsTest(TestCasePlus):
torch.testing.assert_close(output, presaved_output, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(presaved_output, postsaved_output)
@require_accelerate
@mark.accelerate_tests
@require_torch_accelerator
def test_save_offloaded_model_with_direct_params(self):
from accelerate import dispatch_model
device_map = {"submodule": "cpu", "linear": f"{torch_device}:0"}
model = ModelWithDirectParamSubmodule(PretrainedConfig())
dispatch_model(model, device_map)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
@require_accelerate
@mark.accelerate_tests
@require_torch_accelerator