Fix a couple of issues from #36335 (#36453)

* fix

* style

* better allocation

* fix

* fix

* style

* revert disk

* exit

* style

* return if nothing to cache

* dtensor guard

* fix regression

* fix regression

* fix

* fix
This commit is contained in:
Marc Sun
2025-03-01 07:12:17 +01:00
committed by GitHub
parent 2c5d038f92
commit a40f1ac602

View File

@@ -41,7 +41,6 @@ import torch.distributed.tensor
from huggingface_hub import split_torch_state_dict_into_shards from huggingface_hub import split_torch_state_dict_into_shards
from packaging import version from packaging import version
from torch import Tensor, nn from torch import Tensor, nn
from torch.distributed.tensor import DTensor, Shard
from torch.distributions import constraints from torch.distributions import constraints
from torch.nn import CrossEntropyLoss, Identity from torch.nn import CrossEntropyLoss, Identity
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
@@ -67,7 +66,6 @@ from .pytorch_utils import ( # noqa: F401
translate_to_torch_parallel_style, translate_to_torch_parallel_style,
) )
from .quantizers import AutoHfQuantizer, HfQuantizer from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
from .safetensors_conversion import auto_conversion from .safetensors_conversion import auto_conversion
from .utils import ( from .utils import (
ACCELERATE_MIN_VERSION, ACCELERATE_MIN_VERSION,
@@ -181,6 +179,9 @@ else:
if is_peft_available(): if is_peft_available():
from .utils import find_adapter_config_file from .utils import find_adapter_config_file
if is_torch_greater_or_equal("2.5"):
from torch.distributed.tensor import DTensor, Shard
SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel") SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
TORCH_INIT_FUNCTIONS = { TORCH_INIT_FUNCTIONS = {
@@ -702,7 +703,7 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
return shared_tensors, identical return shared_tensors, identical
def find_submodule_and_param_name(model, long_key, start_prefix): def find_submodule_and_param_name(model, long_key, start_prefix=""):
""" """
A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed
from the start of the key from the start of the key
@@ -767,7 +768,6 @@ def _load_state_dict_into_meta_model(
is_safetensors=False, is_safetensors=False,
keep_in_fp32_modules=None, keep_in_fp32_modules=None,
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys
device_mesh=None, device_mesh=None,
shard_file=None, shard_file=None,
): ):
@@ -786,12 +786,7 @@ def _load_state_dict_into_meta_model(
if device_map is not None and device_map.get("", None) is not None: if device_map is not None and device_map.get("", None) is not None:
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""] tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
with safe_open(shard_file, framework="pt", device=tensor_device) as file_pointer: device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))
error_msgs = []
is_quantized = hf_quantizer is not None
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
# we need this later to initialize tensor parallelism # we need this later to initialize tensor parallelism
if device_mesh is not None: if device_mesh is not None:
@@ -799,24 +794,42 @@ def _load_state_dict_into_meta_model(
for submodule in model.modules(): for submodule in model.modules():
full_tp_plan.update(getattr(submodule, "_tp_plan", {})) full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
for serialized_param_name, empty_param in state_dict.items(): file_pointer = None
# param_name is the raw, serialized name bin_state_dict = None
# new_param_name is the model's equivalent if shard_file.endswith(".safetensors"):
module_name, _ = model.rename_key(serialized_param_name) file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
if module_name not in expected_keys: else:
continue bin_state_dict = load_state_dict(shard_file, map_location="cpu")
layer, param_type = module_name.rsplit(".", 1)
# param name needs to stay untouched as it's in the file error_msgs = []
param = file_pointer.get_slice(serialized_param_name)
is_quantized = hf_quantizer is not None
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
for serialized_param_name, empty_param in state_dict.items():
# serialized_param_name is the raw, serialized name
# fixed_param_name is the model's equivalent
fixed_param_name, _ = model.rename_key(serialized_param_name)
if fixed_param_name not in expected_keys:
continue
# we need to use serialized_param_name as file pointer is untouched
param = (
file_pointer.get_slice(serialized_param_name)
if shard_file.endswith(".safetensors")
else bin_state_dict[serialized_param_name]
)
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them. # in int/uint/bool and not cast them.
param_casting_dtype = None param_casting_dtype = None
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
if ( if (
keep_in_fp32_modules is not None keep_in_fp32_modules is not None
and keep_in_fp32_modules.search(module_name) and keep_in_fp32_modules.search(fixed_param_name)
and dtype == torch.float16 and dtype == torch.float16
): ):
param_casting_dtype = torch.float32 param_casting_dtype = torch.float32
@@ -824,15 +837,10 @@ def _load_state_dict_into_meta_model(
param_casting_dtype = dtype param_casting_dtype = dtype
if device_mesh is not None: # In this case, the param is already on the correct device! if device_mesh is not None: # In this case, the param is already on the correct device!
try: module_to_tp, param_type = find_submodule_and_param_name(model, fixed_param_name)
module_to_tp: torch.nn.Module = model.get_submodule(layer)
except Exception:
raise ValueError(
"The config tp plan is wrong because the layer is not a liner layer, nor an embedding"
)
current_module_plan = None current_module_plan = None
full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+") full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
if plan := re.search(full_tp_plan_, module_name): if plan := re.search(full_tp_plan_, fixed_param_name):
match = re.sub("[0-9]+", "*", plan[0]) match = re.sub("[0-9]+", "*", plan[0])
current_module_plan = full_tp_plan[match] current_module_plan = full_tp_plan[match]
@@ -860,63 +868,61 @@ def _load_state_dict_into_meta_model(
if isinstance(module_to_tp.weight, nn.Parameter): if isinstance(module_to_tp.weight, nn.Parameter):
local_parameter = torch.nn.Parameter(local_parameter) local_parameter = torch.nn.Parameter(local_parameter)
module_to_tp.weight = local_parameter module_to_tp.weight = local_parameter
input_fn = partial( input_fn = partial(tp_layer._prepare_input_fn, tp_layer.input_layouts, tp_layer.desired_input_layouts)
tp_layer._prepare_input_fn, tp_layer.input_layouts, tp_layer.desired_input_layouts output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output)
)
output_fn = partial(
tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output
)
distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn) distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
else: else:
module_to_tp.load_state_dict({param_type: param[:]}, False, True) module_to_tp.load_state_dict({param_type: param[:]}, strict=False, assign=True)
else: else:
if device_map is None: if device_map is None:
param_device = "cpu" param_device = "cpu"
else: else:
module_name = module_name.rsplit(".", 1)[0] module_layer = re.search(device_map_regex, fixed_param_name)
device_map_regex = "|".join(device_map.keys()) if not module_layer:
module_layer = re.search(device_map_regex, module_name) raise ValueError(f"{fixed_param_name} doesn't have any device set.")
if module_name == "" or device_map_regex is None:
raise ValueError(
f"`device_map` is used, but {module_name} doesn't have any device set. {device_map}"
)
else: else:
param_device = device_map[module_layer.group()] param_device = device_map[module_layer.group()]
if param_device == "disk" and not is_safetensors: if param_device == "disk":
offload_index = offload_weight(param[:], module_name, offload_folder, offload_index) if not is_safetensors:
offload_index = offload_weight(param[:], fixed_param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None: elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param[:], module_name, state_dict_folder, state_dict_index) state_dict_index = offload_weight(param[:], fixed_param_name, state_dict_folder, state_dict_index)
elif ( elif (
not is_quantized not is_quantized
or (not hf_quantizer.requires_parameters_quantization) or (not hf_quantizer.requires_parameters_quantization)
or ( or (
not hf_quantizer.check_quantized_param( not hf_quantizer.check_quantized_param(
model, param, module_name, state_dict, param_device=param_device, device_map=device_map model,
param,
fixed_param_name,
state_dict,
param_device=param_device,
device_map=device_map,
) )
) )
): ):
if is_fsdp_enabled(): if is_fsdp_enabled():
param_device = "cpu" if is_local_dist_rank_0() else "meta" param_device = "cpu" if is_local_dist_rank_0() else "meta"
module = model.get_submodule(layer) module, param_type = find_submodule_and_param_name(model, fixed_param_name)
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype: if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
param = param[:].to(param_casting_dtype) param = param[:].to(param_casting_dtype)
module.load_state_dict( module.load_state_dict(
{param_type: param[:].to(param_device)}, {param_type: param[:].to(param_device)},
False, strict=False,
True, assign=True,
) )
else: else:
hf_quantizer.create_quantized_param( hf_quantizer.create_quantized_param(
model, param[:], module_name, param_device, state_dict, unexpected_keys model, param[:], fixed_param_name, param_device, state_dict, unexpected_keys
) )
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
# and then cast it to CPU to avoid excessive memory usage on each GPU # and then cast it to CPU to avoid excessive memory usage on each GPU
# in comparison to the sharded model across GPUs. # in comparison to the sharded model across GPUs.
if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
module, tensor_name = get_module_from_name(model, module_name) module, param_type = find_submodule_and_param_name(model, fixed_param_name)
value = getattr(module, tensor_name) value = getattr(module, param_type)
param_to = "cpu" param_to = "cpu"
if is_fsdp_enabled() and not is_local_dist_rank_0(): if is_fsdp_enabled() and not is_local_dist_rank_0():
param_to = "meta" param_to = "meta"
@@ -924,7 +930,9 @@ def _load_state_dict_into_meta_model(
if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params": if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
val_kwargs["requires_grad"] = False val_kwargs["requires_grad"] = False
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__) value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
setattr(module, tensor_name, value) setattr(module, param_type, value)
if file_pointer is not None:
file_pointer.__exit__(None, None, None)
return error_msgs, offload_index, state_dict_index return error_msgs, offload_index, state_dict_index
@@ -4966,7 +4974,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
ignore_mismatched_sizes, ignore_mismatched_sizes,
prefix, prefix,
) )
if low_cpu_mem_usage and shard_file.endswith(".safetensors"): if low_cpu_mem_usage:
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
for key, param in model_to_load.state_dict().items(): for key, param in model_to_load.state_dict().items():
if param.device == torch.device("meta"): if param.device == torch.device("meta"):
@@ -5840,18 +5848,31 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
accelerator_device_map = { accelerator_device_map = {
param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"] param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"]
} }
if not len(accelerator_device_map):
return
parameter_count = defaultdict(lambda: 0) parameter_count = defaultdict(lambda: 0)
allocation_factor = 1
if torch.distributed.is_initialized() or len(set(accelerator_device_map.values())) >= 2:
allocation_factor = 2
for param_name, device in accelerator_device_map.items(): for param_name, device in accelerator_device_map.items():
try: try:
param = model.get_parameter(param_name) param = model.get_parameter(param_name)
except AttributeError: except AttributeError:
param = model.get_buffer(param_name) param = model.get_buffer(param_name)
parameter_count[device] += int(math.prod(param.shape) * 2) parameter_count[device] += int(math.prod(param.shape) * allocation_factor)
dtype = dtype if dtype is not None else torch.float32 dtype = dtype if dtype is not None else torch.float32
# This will kick off the caching allocator to avoid having to Malloc afterwards # This will kick off the caching allocator to avoid having to Malloc afterwards
for device, param_count in parameter_count.items(): for device, param_count in parameter_count.items():
_ = torch.empty(int(param_count), dtype=dtype, device=device, requires_grad=False) max_memory_device = None
if device.type == "cuda":
max_memory_device = torch.cuda.mem_get_info(device.index)[0]
# allocate only if we have enough memory
if max_memory_device is None or max_memory_device > param_count * dtype_byte_size(dtype):
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):