[Core] [Offloading] Enable saving offloaded models with multiple shared tensor groups (#39263)
* fix counting meta tensors, fix onloading meta tensors Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * remove unrelated fix Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> * add test Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> --------- Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
@@ -1187,6 +1187,29 @@ class ModelUtilsTest(TestCasePlus):
|
||||
torch.testing.assert_close(output, presaved_output, rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(presaved_output, postsaved_output)
|
||||
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
@require_torch_accelerator
|
||||
def test_save_offloaded_model_dynamic_tied_weights_keys(self):
|
||||
from accelerate import dispatch_model
|
||||
|
||||
device_map = {"base": f"{torch_device}:0", "linear": "cpu", "linear2": "cpu"}
|
||||
model = ModelWithHead(PretrainedConfig())
|
||||
dispatch_model(model, device_map)
|
||||
|
||||
transform_a = torch.nn.Linear(1, 1, bias=False)
|
||||
transform_a._dynamic_tied_weights_keys = ["weight"]
|
||||
transform_b = torch.nn.Linear(1, 1, bias=False)
|
||||
transform_b._dynamic_tied_weights_keys = ["weight"]
|
||||
|
||||
model.linear.register_module("transform_a", transform_a)
|
||||
model.linear.register_module("transform_b", transform_b)
|
||||
model.linear2.register_module("transform_a", transform_a)
|
||||
model.linear2.register_module("transform_b", transform_b)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
@require_safetensors
|
||||
def test_use_safetensors(self):
|
||||
# Should not raise anymore
|
||||
|
||||
Reference in New Issue
Block a user