Fix serialization for offloaded model (#31727)
* Fix serialization * style * add test
This commit is contained in:
@@ -2518,9 +2518,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
if state_dict is None:
|
if state_dict is None:
|
||||||
# if any model parameters are offloaded to the disk, make module map
|
# if any model parameters are offloaded, make module map
|
||||||
if hasattr(self, "hf_device_map") and (
|
if (
|
||||||
"cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values()
|
hasattr(self, "hf_device_map")
|
||||||
|
and len(set(self.hf_device_map.values())) > 1
|
||||||
|
and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
|
||||||
):
|
):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
|
"Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
|
||||||
@@ -2532,7 +2534,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
for key in module_state_dict:
|
for key in module_state_dict:
|
||||||
module_map[name + f".{key}"] = module
|
module_map[name + f".{key}"] = module
|
||||||
|
|
||||||
state_dict = model_to_save.state_dict()
|
state_dict = model_to_save.state_dict()
|
||||||
|
|
||||||
# Translate state_dict from smp to hf if saving with smp >= 1.10
|
# Translate state_dict from smp to hf if saving with smp >= 1.10
|
||||||
@@ -2655,7 +2656,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
and reg.fullmatch(filename_no_suffix) is not None
|
and reg.fullmatch(filename_no_suffix) is not None
|
||||||
):
|
):
|
||||||
os.remove(full_filename)
|
os.remove(full_filename)
|
||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
for shard_file, tensors in state_dict_split.filename_to_tensors.items():
|
for shard_file, tensors in state_dict_split.filename_to_tensors.items():
|
||||||
shard = {tensor: state_dict[tensor] for tensor in tensors}
|
shard = {tensor: state_dict[tensor] for tensor in tensors}
|
||||||
@@ -2667,15 +2667,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
f"Please upgrade accelerate with `pip install -U accelerate`"
|
f"Please upgrade accelerate with `pip install -U accelerate`"
|
||||||
)
|
)
|
||||||
# init state_dict for this shard
|
# init state_dict for this shard
|
||||||
state_dict = {name: "" for name in shard}
|
shard_state_dict = {name: "" for name in shard}
|
||||||
for module_name in shard:
|
for module_name in shard:
|
||||||
module = module_map[module_name]
|
module = module_map[module_name]
|
||||||
# update state dict with onloaded parameters
|
# update state dict with onloaded parameters
|
||||||
state_dict = get_state_dict_from_offload(module, module_name, state_dict)
|
shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
|
||||||
|
|
||||||
# assign shard to be the completed state dict
|
# assign shard to be the completed state dict
|
||||||
shard = state_dict
|
shard = shard_state_dict
|
||||||
del state_dict
|
del shard_state_dict
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
if safe_serialization:
|
if safe_serialization:
|
||||||
|
|||||||
@@ -1065,6 +1065,23 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
# This check we did call the fake head request
|
# This check we did call the fake head request
|
||||||
mock_head.assert_called()
|
mock_head.assert_called()
|
||||||
|
|
||||||
|
@require_accelerate
|
||||||
|
@mark.accelerate_tests
|
||||||
|
def test_save_model_with_device_map_cpu(self):
|
||||||
|
model_id = "hf-internal-testing/tiny-random-gpt2"
|
||||||
|
inputs = torch.tensor([[1, 2, 3]])
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
|
||||||
|
output = model(inputs)[0]
|
||||||
|
model.save_pretrained(
|
||||||
|
tmp_dir, max_shard_size="200KB"
|
||||||
|
) # model is 1.6MB, max shard size is allocated to cpu by default
|
||||||
|
saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map="cpu")
|
||||||
|
saved_model_output = saved_model(inputs)[0]
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(output, saved_model_output))
|
||||||
|
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
@mark.accelerate_tests
|
@mark.accelerate_tests
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
@@ -1083,9 +1100,9 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
|
|
||||||
# check_models_equal requires onloaded tensors
|
# check_models_equal requires onloaded tensors
|
||||||
model_id = "hf-internal-testing/tiny-random-gpt2"
|
model_id = "hf-internal-testing/tiny-random-gpt2"
|
||||||
onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
|
onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu").to(f"{torch_device}:0")
|
||||||
inputs = torch.tensor([[1, 2, 3]]).to(f"{torch_device}:0")
|
inputs = torch.tensor([[1, 2, 3]]).to(f"{torch_device}:0")
|
||||||
cpu_output = onloaded_model(inputs)[0]
|
output = onloaded_model(inputs)[0]
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
offload_folder = os.path.join(tmp_dir, "offload")
|
offload_folder = os.path.join(tmp_dir, "offload")
|
||||||
@@ -1099,7 +1116,7 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map=device_map)
|
saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map=device_map)
|
||||||
postsaved_output = saved_model(inputs)[0]
|
postsaved_output = saved_model(inputs)[0]
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(cpu_output, presaved_output, atol=1e-4))
|
self.assertTrue(torch.allclose(output, presaved_output, atol=1e-4))
|
||||||
self.assertTrue(torch.allclose(presaved_output, postsaved_output))
|
self.assertTrue(torch.allclose(presaved_output, postsaved_output))
|
||||||
|
|
||||||
@require_safetensors
|
@require_safetensors
|
||||||
|
|||||||
Reference in New Issue
Block a user