Disk offload fix (#17428)
* Fix offload to disk for big models
* Add test
* Fix test for other models
This commit is contained in:
@@ -597,11 +597,12 @@ def _load_state_dict_into_meta_model(
|
|||||||
raise ValueError(f"{param_name} doesn't have any device set.")
|
raise ValueError(f"{param_name} doesn't have any device set.")
|
||||||
param_device = device_map[module_name]
|
param_device = device_map[module_name]
|
||||||
|
|
||||||
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
|
||||||
if param_device == "disk":
|
if param_device == "disk":
|
||||||
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
|
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
|
||||||
elif param_device == "cpu" and state_dict_index is not None:
|
elif param_device == "cpu" and state_dict_index is not None:
|
||||||
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
|
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
|
||||||
|
else:
|
||||||
|
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
||||||
|
|
||||||
return error_msgs, offload_index, state_dict_index
|
return error_msgs, offload_index, state_dict_index
|
||||||
|
|
||||||
@@ -2216,6 +2217,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
offload_state_dict=False,
|
offload_state_dict=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
):
|
):
|
||||||
|
if device_map is not None and "disk" in device_map.values() and offload_folder is None:
|
||||||
|
raise ValueError(
|
||||||
|
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder` for"
|
||||||
|
" them."
|
||||||
|
)
|
||||||
# Retrieve missing & unexpected_keys
|
# Retrieve missing & unexpected_keys
|
||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
expected_keys = list(model_state_dict.keys())
|
expected_keys = list(model_state_dict.keys())
|
||||||
|
|||||||
@@ -2214,6 +2214,42 @@ class ModelTesterMixin:
|
|||||||
else:
|
else:
|
||||||
self.assertEqual(param.device, torch.device(param_device))
|
self.assertEqual(param.device, torch.device(param_device))
|
||||||
|
|
||||||
|
@require_accelerate
@require_torch_gpu
def test_disk_offload(self):
    """Check that a model reloaded with a memory-capped `device_map="auto"` matches the original.

    For every model class that defines `_no_split_modules`, the test:

    1. runs a reference forward pass with the whole model on `torch_device`;
    2. saves the model and reloads it with `max_memory` capped at 40% of the model size for
       both device 0 and the CPU, first without `offload_folder` (expected to raise
       `ValueError`) and then with one;
    3. verifies the resulting `hf_device_map` is respected and that the first output of the
       reloaded model is close to the reference output.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    # NOTE(review): presumably bumps the depth so there are enough layers to spread over
    # GPU, CPU and disk -- confirm against the model testers.
    if isinstance(getattr(config, "num_hidden_layers", None), int) and config.num_hidden_layers < 5:
        config.num_hidden_layers = 5

    for model_class in self.all_model_classes:
        if model_class._no_split_modules is None:
            continue

        # Keep the class-specific inputs in their own variable: rebinding `inputs_dict` here
        # would feed every later model class inputs already prepared for the previous one.
        inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
        model = model_class(config).eval()
        model = model.to(torch_device)
        base_output = model(**inputs_dict_class)

        model_size = compute_module_sizes(model)[""]
        # Device 0 and the CPU are each capped at 40% of the model size, so together they
        # can hold at most 80% of it.
        max_size = int(0.4 * model_size)
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir)

            max_memory = {0: max_size, "cpu": max_size}
            with self.assertRaises(ValueError):
                # This errors out because it's missing an offload folder
                model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

            new_model = model_class.from_pretrained(
                tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
            )

            self.check_device_map_is_respected(new_model, new_model.hf_device_map)
            new_output = new_model(**inputs_dict_class)

            self.assertTrue(torch.allclose(base_output[0], new_output[0]))
|
||||||
|
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
def test_cpu_offload(self):
|
def test_cpu_offload(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user