Remove all traces of low_cpu_mem_usage (#38792)

* remove it from all py files

* remove it from the doc

* remove it from examples

* style

* remove traces of _fast_init

* Update test_peft_integration.py

* CIs
This commit is contained in:
Cyril Vallez
2025-06-12 16:39:33 +02:00
committed by GitHub
parent 3542e0b844
commit 4b8ec667e9
76 changed files with 100 additions and 598 deletions

View File

@@ -578,87 +578,6 @@ class ModelTesterMixin:
f"The following keys are not properly handled by `_init_weights()`:\n{different_weights}",
)
@slow
@require_accelerate
@mark.accelerate_tests
def test_save_load_low_cpu_mem_usage(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
with tempfile.TemporaryDirectory() as saved_model_path:
for model_class in self.all_model_classes:
model_to_save = model_class(config)
model_to_save.save_pretrained(saved_model_path)
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
@slow
@require_accelerate
@mark.accelerate_tests
def test_save_load_low_cpu_mem_usage_checkpoints(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
with tempfile.TemporaryDirectory() as saved_model_path:
for model_class in self.all_model_classes:
model_to_save = model_class(config)
model_to_save.config.save_pretrained(saved_model_path)
torch.save(model_to_save.state_dict(), os.path.join(saved_model_path, "pytorch_model.bin"))
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
@slow
@require_accelerate
@mark.accelerate_tests
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
with tempfile.TemporaryDirectory() as saved_model_path:
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model_to_save = model_class(config)
model_to_save.save_pretrained(saved_model_path, safe_serialization=False)
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
from accelerate.utils.modeling import named_module_tensors
# Load the low usage and the normal models.
model_low_usage, loading_info = model_class.from_pretrained(
saved_model_path,
low_cpu_mem_usage=True,
output_loading_info=True,
)
model_non_low_usage = model_class.from_pretrained(saved_model_path)
# Check that there were no missing keys.
self.assertEqual(loading_info["missing_keys"], [])
# The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then
# subsequently loaded with the correct values and onto the correct device. We check if there are any
# remaining params that were not properly loaded.
for name, tensor in named_module_tensors(model_low_usage, recurse=True):
self.assertNotEqual(
tensor.device,
torch.device("meta"),
"Tensor '" + name + "' has not been properly loaded and has device=meta.",
)
# Check that the parameters are equal.
for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()):
self.assertEqual(p1.data.ne(p2.data).sum(), 0)
# Check that the state dict keys are equal.
self.assertEqual(set(model_low_usage.state_dict().keys()), set(model_non_low_usage.state_dict().keys()))
# Check that the shared tensors are equal.
tensor_ptrs1 = collections.defaultdict(list)
for name, tensor in model_low_usage.state_dict().items():
tensor_ptrs1[id_tensor_storage(tensor)].append(name)
tied_params1 = [names for _, names in tensor_ptrs1.items() if len(names) > 1]
tensor_ptrs2 = collections.defaultdict(list)
for name, tensor in model_non_low_usage.state_dict().items():
tensor_ptrs2[id_tensor_storage(tensor)].append(name)
tied_params2 = [names for _, names in tensor_ptrs2.items() if len(names) > 1]
self.assertEqual(tied_params1, tied_params2)
def test_torch_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if config.__class__ not in MODEL_MAPPING:
@@ -4100,7 +4019,6 @@ class ModelTesterMixin:
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
load_in_4bit=True,
)
@@ -4173,7 +4091,6 @@ class ModelTesterMixin:
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
)
.to(torch_device)
.eval()
@@ -4248,7 +4165,6 @@ class ModelTesterMixin:
tmpdirname,
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
)
.to(torch_device)
.eval()