Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True (#29024)
* Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True * Testing for the non-safe-tensors case, since the default is safe-tensors already * Running fixup/fix-copies * Adding accelerate annotations to tests
This commit is contained in:
@@ -437,6 +437,88 @@ class ModelTesterMixin:
|
||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
@slow
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
def test_save_load_low_cpu_mem_usage(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
with tempfile.TemporaryDirectory() as saved_model_path:
|
||||
for model_class in self.all_model_classes:
|
||||
model_to_save = model_class(config)
|
||||
model_to_save.save_pretrained(saved_model_path)
|
||||
|
||||
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
|
||||
|
||||
@slow
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
def test_save_load_low_cpu_mem_usage_checkpoints(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
with tempfile.TemporaryDirectory() as saved_model_path:
|
||||
for model_class in self.all_model_classes:
|
||||
model_to_save = model_class(config)
|
||||
model_to_save.config.save_pretrained(saved_model_path)
|
||||
torch.save(model_to_save.state_dict(), os.path.join(saved_model_path, "pytorch_model.bin"))
|
||||
|
||||
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
|
||||
|
||||
@slow
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
|
||||
with tempfile.TemporaryDirectory() as saved_model_path:
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model_to_save = model_class(config)
|
||||
|
||||
model_to_save.save_pretrained(saved_model_path, safe_serialization=False)
|
||||
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
|
||||
|
||||
def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
|
||||
# Load the low usage and the normal models.
|
||||
model_low_usage, loading_info = model_class.from_pretrained(
|
||||
saved_model_path,
|
||||
low_cpu_mem_usage=True,
|
||||
output_loading_info=True,
|
||||
)
|
||||
model_non_low_usage = model_class.from_pretrained(saved_model_path)
|
||||
|
||||
# Check that there were no missing keys.
|
||||
self.assertEqual(loading_info["missing_keys"], [])
|
||||
|
||||
# The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then
|
||||
# subsequently loaded with the correct values and onto the correct device. We check if there are any
|
||||
# remaining params that were not properly loaded.
|
||||
for name, param in model_low_usage.named_parameters():
|
||||
self.assertNotEqual(
|
||||
param.device,
|
||||
torch.device("meta"),
|
||||
"Parameter '" + name + "' has not been properly loaded and has device=meta.",
|
||||
)
|
||||
|
||||
# Tests moving the model to a device other than meta.
|
||||
model_low_usage.to(torch_device)
|
||||
|
||||
# Check that the parameters are equal.
|
||||
for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()):
|
||||
self.assertEquals(p1.data.ne(p2.data).sum(), 0)
|
||||
|
||||
# Check that the state dict keys are equal.
|
||||
self.assertEqual(set(model_low_usage.state_dict().keys()), set(model_non_low_usage.state_dict().keys()))
|
||||
|
||||
# Check that the shared tensors are equal.
|
||||
tensor_ptrs1 = collections.defaultdict(list)
|
||||
for name, tensor in model_low_usage.state_dict().items():
|
||||
tensor_ptrs1[id_tensor_storage(tensor)].append(name)
|
||||
tied_params1 = [names for _, names in tensor_ptrs1.items() if len(names) > 1]
|
||||
|
||||
tensor_ptrs2 = collections.defaultdict(list)
|
||||
for name, tensor in model_non_low_usage.state_dict().items():
|
||||
tensor_ptrs2[id_tensor_storage(tensor)].append(name)
|
||||
tied_params2 = [names for _, names in tensor_ptrs2.items() if len(names) > 1]
|
||||
|
||||
self.assertEqual(tied_params1, tied_params2)
|
||||
|
||||
def test_fast_init_context_manager(self):
|
||||
# 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
|
||||
class MyClass(PreTrainedModel):
|
||||
|
||||
Reference in New Issue
Block a user