Extend save_pretrained to offloaded models (#27412)
* added hidden subset * debugged hidden subset contrastive search * added contrastive search compression * debugged compressed contrastive search * memory reduction for contrastive search * debugged mem red * added low memory option feature * debugged mem optmimization output stack * debugged mem optmimization output stack * debugged low mem * added low mem cache * fixed 2047 tensor view * debugged 2042 past key val inputs * reformatted tensors * changed low mem output * final clean * removed subset hidden csearch * fixed hidden device * fixed hidden device * changed compressor dtype * removed hstate compression * integrated csearch in generate * test csearch integration into generation exit() * fixed csearch kwarg integration with generation * final wrap and added doc * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * added debug print * direct hstate cat * direct hstate cat * direct hstate cat debug * direct hstate cat debug * expanded full hidden state stack * expanded full hidden state stack * matched dims for hstates * matched dims for hstates * logits fix * equality test * equality hidden debug * debug * added prints for debug * added prints for debug * equality check * switched squeeze dim * input format debug * tracing top_k_ids * removed trace * added test context * added jitter * added jitter * added jitter * returned state * rebuilt past key value reconstruction * debugged * cleaned traces * added selection for pkv * changed output to dict * cleaned * cleaned * cleaned up contrastive search test * moved low_memory kwarg * debugged * changed low mem test batch size to 1 * removed output * debugged test input shape * reformatted csearch test * added trace * removed unsqueeze on 
final forward pass * replaced unsqueeze with view * removed traces * cleaned * debugged model kwargs * removed special models from test * ran make quality * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * refactored * refactored * refactored * make fixup * renamed flag sequential * renamed flag sequential * iterative onloading * black style and test utils * added traces for integrated test * debugged * added traces * make style * removed traces, make style * included suggestions and added test * debugged test * added offload module check and make style * is_accelerate_available and make style * added test decorator * changed test model and config spec * added offload condition * added lazy loading for each shard * debugged * modified sharding * debugged * added traces * removed safe serialization * no index overload; * trace on safe save ptrs * added ptr condition * debugged * debugged ptr * moved module map init * remake shard only for offloaded modules * refactored * debugged * refactored * debugged * cleaned and make style * cleaned and make style * added trace * sparse module map * debugged * removed module map conditional * refactored * debug * debugged * added traces * added shard mem trace * added shard mem trace * removed underlying storage check * refactored * memory leak removal and make style * cleaned * swapped test decs and make style * added mem checks and make style * added free mem warning * implemented some suggestions * moved onloading to accelerate * refactored for accelerate integration * cleaned test * make style * debugged offload map name * cleaned and make style * replaced meta device check for sharding * cleaned and make style * implemented some suggestions * more suggestions * update warning Co-authored-by: Marc Sun 
<57196510+SunMarc@users.noreply.github.com> * more suggestions * make style * new make style * Update src/transformers/modeling_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -119,6 +119,10 @@ if is_accelerate_available():
|
|||||||
set_module_tensor_to_device,
|
set_module_tensor_to_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
accelerate_version = version.parse(importlib.metadata.version("accelerate"))
|
||||||
|
if accelerate_version >= version.parse("0.31"):
|
||||||
|
from accelerate.utils.modeling import get_state_dict_from_offload
|
||||||
|
|
||||||
if is_safetensors_available():
|
if is_safetensors_available():
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
from safetensors.torch import load_file as safe_load_file
|
from safetensors.torch import load_file as safe_load_file
|
||||||
@@ -374,13 +378,12 @@ def shard_checkpoint(
|
|||||||
storage_id = id_tensor_storage(weight)
|
storage_id = id_tensor_storage(weight)
|
||||||
|
|
||||||
# If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
|
# If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
|
||||||
if storage_id in storage_id_to_block:
|
if storage_id in storage_id_to_block and weight.device != torch.device("meta"):
|
||||||
block_id = storage_id_to_block[storage_id]
|
block_id = storage_id_to_block[storage_id]
|
||||||
sharded_state_dicts[block_id][key] = weight
|
sharded_state_dicts[block_id][key] = weight
|
||||||
continue
|
continue
|
||||||
|
|
||||||
weight_size = weight.numel() * dtype_byte_size(weight.dtype)
|
weight_size = weight.numel() * dtype_byte_size(weight.dtype)
|
||||||
|
|
||||||
# If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
|
# If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
|
||||||
# weight in the current shard.
|
# weight in the current shard.
|
||||||
if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
|
if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0:
|
||||||
@@ -2504,8 +2507,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
current_peft_config = self.peft_config[active_adapter]
|
current_peft_config = self.peft_config[active_adapter]
|
||||||
current_peft_config.save_pretrained(save_directory)
|
current_peft_config.save_pretrained(save_directory)
|
||||||
|
|
||||||
|
# for offloaded modules
|
||||||
|
module_map = {}
|
||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
if state_dict is None:
|
if state_dict is None:
|
||||||
|
# if any module is mapped to the cpu or to the disk in `hf_device_map`, build the module map
|
||||||
|
if hasattr(self, "hf_device_map") and (
|
||||||
|
"cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values()
|
||||||
|
):
|
||||||
|
warnings.warn(
|
||||||
|
"Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
|
||||||
|
)
|
||||||
|
for name, module in model_to_save.named_modules():
|
||||||
|
if name == "":
|
||||||
|
continue
|
||||||
|
module_state_dict = module.state_dict()
|
||||||
|
|
||||||
|
for key in module_state_dict:
|
||||||
|
module_map[name + f".{key}"] = module
|
||||||
|
|
||||||
state_dict = model_to_save.state_dict()
|
state_dict = model_to_save.state_dict()
|
||||||
|
|
||||||
# Translate state_dict from smp to hf if saving with smp >= 1.10
|
# Translate state_dict from smp to hf if saving with smp >= 1.10
|
||||||
@@ -2531,12 +2552,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# In the non-tensor case, fall back to the pointer of the object itself
|
# In the non-tensor case, fall back to the pointer of the object itself
|
||||||
ptrs[id(tensor)].append(name)
|
ptrs[id(tensor)].append(name)
|
||||||
|
|
||||||
# These are all the pointers of shared tensors.
|
# These are all the pointers of shared tensors
|
||||||
|
if hasattr(self, "hf_device_map"):
|
||||||
|
# if the model has offloaded parameters, we must check using find_tied_parameters()
|
||||||
|
tied_params = find_tied_parameters(self)
|
||||||
|
if tied_params:
|
||||||
|
tied_names = tied_params[0]
|
||||||
|
shared_ptrs = {
|
||||||
|
ptr: names for ptr, names in ptrs.items() if any(name in tied_names for name in names)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
shared_ptrs = {}
|
||||||
|
else:
|
||||||
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
|
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
|
||||||
error_names = []
|
|
||||||
to_delete_names = set()
|
|
||||||
# Recursively descend to find tied weight keys
|
# Recursively descend to find tied weight keys
|
||||||
_tied_weights_keys = _get_tied_weight_keys(self)
|
_tied_weights_keys = _get_tied_weight_keys(self)
|
||||||
|
error_names = []
|
||||||
|
to_delete_names = set()
|
||||||
for names in shared_ptrs.values():
|
for names in shared_ptrs.values():
|
||||||
# Removing the keys which are declared as known duplicates on
|
# Removing the keys which are declared as known duplicates on
|
||||||
# load. This allows to make sure the name which is kept is consistent.
|
# load. This allows to make sure the name which is kept is consistent.
|
||||||
@@ -2609,6 +2642,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
for shard_file, shard in shards.items():
|
for shard_file, shard in shards.items():
|
||||||
|
# remake shard with onloaded parameters if necessary
|
||||||
|
if module_map:
|
||||||
|
if accelerate_version < version.parse("0.31"):
|
||||||
|
raise ImportError(
|
||||||
|
f"You need accelerate version to be greater or equal than 0.31 to save models with offloaded parameters. Detected version {accelerate_version}. "
|
||||||
|
f"Please upgrade accelerate with `pip install -U accelerate`"
|
||||||
|
)
|
||||||
|
# init state_dict for this shard
|
||||||
|
state_dict = {name: "" for name in shard}
|
||||||
|
for module_name in shard:
|
||||||
|
module = module_map[module_name]
|
||||||
|
# update state dict with onloaded parameters
|
||||||
|
state_dict = get_state_dict_from_offload(module, module_name, state_dict)
|
||||||
|
|
||||||
|
# assign shard to be the completed state dict
|
||||||
|
shard = state_dict
|
||||||
|
del state_dict
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
if safe_serialization:
|
if safe_serialization:
|
||||||
# At some point we will need to deal better with save_function (used for TPU and other distributed
|
# At some point we will need to deal better with save_function (used for TPU and other distributed
|
||||||
# joyfulness), but for now this enough.
|
# joyfulness), but for now this enough.
|
||||||
|
|||||||
@@ -1056,6 +1056,43 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
# This check we did call the fake head request
|
# This check we did call the fake head request
|
||||||
mock_head.assert_called()
|
mock_head.assert_called()
|
||||||
|
|
||||||
|
@require_accelerate
@mark.accelerate_tests
@require_torch_accelerator
def test_save_offloaded_model(self):
    """Check that `save_pretrained` works on a model with cpu- and disk-mapped modules.

    The tiny GPT-2 checkpoint is loaded three times: once as a plain reference model,
    once dispatched with a mixed accelerator/cpu/disk `device_map`, and once more from
    the directory written by `save_pretrained` on the dispatched model. All three must
    produce matching logits for the same input.
    """
    device_map = {
        "transformer.wte": f"{torch_device}:0",
        "transformer.wpe": f"{torch_device}:0",
        "transformer.h.0": "cpu",
        "transformer.h.1": "cpu",
        "transformer.h.2": "cpu",
        "transformer.h.3": "disk",
        "transformer.h.4": "disk",
        "transformer.ln_f": f"{torch_device}:0",
        "lm_head": f"{torch_device}:0",
    }

    # check_models_equal requires onloaded tensors
    model_id = "hf-internal-testing/tiny-random-gpt2"
    execution_device = f"{torch_device}:0"
    # The reference model is explicitly moved next to `inputs`. Loading with a
    # single-device `device_map` is not guaranteed to attach hooks that move inputs,
    # so calling a cpu-resident model with accelerator inputs can fail with a device
    # mismatch.
    onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu").to(execution_device)
    inputs = torch.tensor([[1, 2, 3]]).to(execution_device)
    reference_output = onloaded_model(inputs)[0]

    with tempfile.TemporaryDirectory() as tmp_dir:
        offload_folder = os.path.join(tmp_dir, "offload")
        offloaded_model = AutoModelForCausalLM.from_pretrained(
            model_id, device_map=device_map, offload_folder=offload_folder
        )
        presaved_output = offloaded_model(inputs)[0]
        # A small `max_shard_size` forces the checkpoint to be written as several shards.
        offloaded_model.save_pretrained(
            tmp_dir, max_shard_size="200KB"
        )  # model is 1.6MB, max shard size is allocated to cpu by default
        saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map=device_map)
        postsaved_output = saved_model(inputs)[0]

    # Dispatching must not change the logits (up to a small numerical tolerance) ...
    self.assertTrue(torch.allclose(reference_output, presaved_output, atol=1e-4))
    # ... and neither must the save / reload round trip.
    self.assertTrue(torch.allclose(presaved_output, postsaved_output))
|
||||||
|
|
||||||
@require_safetensors
|
@require_safetensors
|
||||||
def test_use_safetensors(self):
|
def test_use_safetensors(self):
|
||||||
# Should not raise anymore
|
# Should not raise anymore
|
||||||
|
|||||||
Reference in New Issue
Block a user