[Core] [Offloading] Fix saving offloaded submodules (#39280)
* fix counting meta tensors, fix onloading meta tensors
  Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
* remove unrelated fix
  Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
* remove unrelated change
  Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
* add clarifying comment
  Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
* add test_save_offloaded_model_with_direct_params
  Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
* fix merge conflict, add decorators
  Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
---------
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
@@ -158,6 +158,38 @@ if is_torch_available():
|
||||
def forward(self, x):
|
||||
return self.linear2(self.linear(self.base(x)))
|
||||
|
||||
class ModelWithDirectParam(PreTrainedModel):
    """Test model owning a parameter directly alongside a `BaseModel` submodule.

    The forward pass multiplies the input by the transposed direct weight and
    feeds the result to the `base` submodule.
    """

    config_class = PretrainedConfig
    base_model_prefix = "base"

    def __init__(self, config):
        super().__init__(config)
        # Having both a directly-owned parameter and a submodule is helpful
        # for testing offloading logic.
        self.weight = nn.Parameter(torch.rand((5, 5)))
        self.base = BaseModel(config)

    def _init_weights(self, module):
        """Intentionally do nothing."""
        return None

    def forward(self, x):
        """Return `base` applied to `x @ weight.T`."""
        projected = x @ self.weight.T
        return self.base(projected)
|
||||
|
||||
class ModelWithDirectParamSubmodule(PreTrainedModel):
    """Test model nesting a `ModelWithDirectParam` ahead of a linear layer."""

    config_class = PretrainedConfig
    base_model_prefix = "base"

    def __init__(self, config):
        super().__init__(config)
        self.submodule = ModelWithDirectParam(config)
        # Needed so the model can have at least one module on the accelerator.
        self.linear = nn.Linear(5, 5)

    def _init_weights(self, module):
        """Intentionally do nothing."""
        return None

    def forward(self, x):
        """Return `linear` applied to the output of `submodule`."""
        hidden = self.submodule(x)
        return self.linear(hidden)
|
||||
|
||||
class ModelWithHeadAndTiedWeights(PreTrainedModel):
|
||||
base_model_prefix = "base"
|
||||
config_class = PretrainedConfig
|
||||
@@ -1187,6 +1219,19 @@ class ModelUtilsTest(TestCasePlus):
|
||||
torch.testing.assert_close(output, presaved_output, rtol=1e-4, atol=1e-4)
|
||||
torch.testing.assert_close(presaved_output, postsaved_output)
|
||||
|
||||
@require_accelerate
@mark.accelerate_tests
@require_torch_accelerator
def test_save_offloaded_model_with_direct_params(self):
    """Dispatch a model across devices and check that saving it runs.

    `submodule` is mapped to "cpu" and `linear` to `{torch_device}:0`; the
    test only requires `save_pretrained` to complete.
    """
    from accelerate import dispatch_model

    offloaded_model = ModelWithDirectParamSubmodule(PretrainedConfig())
    placement = {"submodule": "cpu", "linear": f"{torch_device}:0"}
    dispatch_model(offloaded_model, placement)

    with tempfile.TemporaryDirectory() as save_dir:
        offloaded_model.save_pretrained(save_dir)
|
||||
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
@require_torch_accelerator
|
||||
|
||||
Reference in New Issue
Block a user