Remove all traces of low_cpu_mem_usage (#38792)
* remove it from all py files * remove it from the doc * remove it from examples * style * remove traces of _fast_init * Update test_peft_integration.py * CIs
This commit is contained in:
@@ -578,87 +578,6 @@ class ModelTesterMixin:
|
||||
f"The following keys are not properly handled by `_init_weights()`:\n{different_weights}",
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
def test_save_load_low_cpu_mem_usage(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
with tempfile.TemporaryDirectory() as saved_model_path:
|
||||
for model_class in self.all_model_classes:
|
||||
model_to_save = model_class(config)
|
||||
model_to_save.save_pretrained(saved_model_path)
|
||||
|
||||
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
|
||||
|
||||
@slow
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
def test_save_load_low_cpu_mem_usage_checkpoints(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
with tempfile.TemporaryDirectory() as saved_model_path:
|
||||
for model_class in self.all_model_classes:
|
||||
model_to_save = model_class(config)
|
||||
model_to_save.config.save_pretrained(saved_model_path)
|
||||
torch.save(model_to_save.state_dict(), os.path.join(saved_model_path, "pytorch_model.bin"))
|
||||
|
||||
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
|
||||
|
||||
@slow
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
|
||||
with tempfile.TemporaryDirectory() as saved_model_path:
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model_to_save = model_class(config)
|
||||
|
||||
model_to_save.save_pretrained(saved_model_path, safe_serialization=False)
|
||||
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
|
||||
|
||||
def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
|
||||
from accelerate.utils.modeling import named_module_tensors
|
||||
|
||||
# Load the low usage and the normal models.
|
||||
model_low_usage, loading_info = model_class.from_pretrained(
|
||||
saved_model_path,
|
||||
low_cpu_mem_usage=True,
|
||||
output_loading_info=True,
|
||||
)
|
||||
model_non_low_usage = model_class.from_pretrained(saved_model_path)
|
||||
|
||||
# Check that there were no missing keys.
|
||||
self.assertEqual(loading_info["missing_keys"], [])
|
||||
|
||||
# The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then
|
||||
# subsequently loaded with the correct values and onto the correct device. We check if there are any
|
||||
# remaining params that were not properly loaded.
|
||||
for name, tensor in named_module_tensors(model_low_usage, recurse=True):
|
||||
self.assertNotEqual(
|
||||
tensor.device,
|
||||
torch.device("meta"),
|
||||
"Tensor '" + name + "' has not been properly loaded and has device=meta.",
|
||||
)
|
||||
|
||||
# Check that the parameters are equal.
|
||||
for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()):
|
||||
self.assertEqual(p1.data.ne(p2.data).sum(), 0)
|
||||
|
||||
# Check that the state dict keys are equal.
|
||||
self.assertEqual(set(model_low_usage.state_dict().keys()), set(model_non_low_usage.state_dict().keys()))
|
||||
|
||||
# Check that the shared tensors are equal.
|
||||
tensor_ptrs1 = collections.defaultdict(list)
|
||||
for name, tensor in model_low_usage.state_dict().items():
|
||||
tensor_ptrs1[id_tensor_storage(tensor)].append(name)
|
||||
tied_params1 = [names for _, names in tensor_ptrs1.items() if len(names) > 1]
|
||||
|
||||
tensor_ptrs2 = collections.defaultdict(list)
|
||||
for name, tensor in model_non_low_usage.state_dict().items():
|
||||
tensor_ptrs2[id_tensor_storage(tensor)].append(name)
|
||||
tied_params2 = [names for _, names in tensor_ptrs2.items() if len(names) > 1]
|
||||
|
||||
self.assertEqual(tied_params1, tied_params2)
|
||||
|
||||
def test_torch_save_load(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
@@ -4100,7 +4019,6 @@ class ModelTesterMixin:
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
attn_implementation="flash_attention_2",
|
||||
low_cpu_mem_usage=True,
|
||||
load_in_4bit=True,
|
||||
)
|
||||
|
||||
@@ -4173,7 +4091,6 @@ class ModelTesterMixin:
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
attn_implementation="flash_attention_2",
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
@@ -4248,7 +4165,6 @@ class ModelTesterMixin:
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
attn_implementation="flash_attention_2",
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
|
||||
Reference in New Issue
Block a user