From a40f1ac602fe900281722254c52ce3773f28eb0e Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Sat, 1 Mar 2025 07:12:17 +0100 Subject: [PATCH] Fix couples of issues from #36335 (#36453) * fix * style * better allocation * fix * fix * style * revert disk * exit * style * return if nothing to cache * dtensor guard * fix regressiion * fix regression * fix * fix --- src/transformers/modeling_utils.py | 289 ++++++++++++++++------------- 1 file changed, 155 insertions(+), 134 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e602d7ca8c..73328e3af5 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -41,7 +41,6 @@ import torch.distributed.tensor from huggingface_hub import split_torch_state_dict_into_shards from packaging import version from torch import Tensor, nn -from torch.distributed.tensor import DTensor, Shard from torch.distributions import constraints from torch.nn import CrossEntropyLoss, Identity from torch.utils.checkpoint import checkpoint @@ -67,7 +66,6 @@ from .pytorch_utils import ( # noqa: F401 translate_to_torch_parallel_style, ) from .quantizers import AutoHfQuantizer, HfQuantizer -from .quantizers.quantizers_utils import get_module_from_name from .safetensors_conversion import auto_conversion from .utils import ( ACCELERATE_MIN_VERSION, @@ -181,6 +179,9 @@ else: if is_peft_available(): from .utils import find_adapter_config_file +if is_torch_greater_or_equal("2.5"): + from torch.distributed.tensor import DTensor, Shard + SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel") TORCH_INIT_FUNCTIONS = { @@ -702,7 +703,7 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor] return shared_tensors, identical -def find_submodule_and_param_name(model, long_key, start_prefix): +def find_submodule_and_param_name(model, long_key, start_prefix=""): """ A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed from the start of the key @@ -767,7 +768,6 @@ def _load_state_dict_into_meta_model( is_safetensors=False, keep_in_fp32_modules=None, unexpected_keys=None, # passing `unexpected` for cleanup from quantization items - pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys device_mesh=None, shard_file=None, ): @@ -786,145 +786,153 @@ def _load_state_dict_into_meta_model( if device_map is not None and device_map.get("", None) is not None: tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""] - with safe_open(shard_file, framework="pt", device=tensor_device) as file_pointer: - error_msgs = [] + device_map_regex = "|".join(sorted(device_map.keys(), reverse=True)) - is_quantized = hf_quantizer is not None + # we need this later to initialize tensor parallelism + if device_mesh is not None: + full_tp_plan = model.config.base_model_tp_plan + for submodule in model.modules(): + full_tp_plan.update(getattr(submodule, "_tp_plan", {})) - is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") + file_pointer = None + bin_state_dict = None + if shard_file.endswith(".safetensors"): + file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) + else: + bin_state_dict = load_state_dict(shard_file, map_location="cpu") - # we need this later to initialize tensor parallelism - if device_mesh is not None: - full_tp_plan = model.config.base_model_tp_plan - for submodule in model.modules(): - full_tp_plan.update(getattr(submodule, "_tp_plan", {})) + error_msgs = [] - for serialized_param_name, empty_param in state_dict.items(): - # param_name is the raw, serialized name - # new_param_name is the model's equivalent - module_name, _ = model.rename_key(serialized_param_name) - if module_name not in expected_keys: - continue - layer, param_type = module_name.rsplit(".", 1) + is_quantized = hf_quantizer is not None - # param name needs to stay untouched as it's in the file - param = file_pointer.get_slice(serialized_param_name) - # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params - # in int/uint/bool and not cast them. - param_casting_dtype = None - is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn - if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: - if ( - keep_in_fp32_modules is not None - and keep_in_fp32_modules.search(module_name) - and dtype == torch.float16 - ): - param_casting_dtype = torch.float32 - else: - param_casting_dtype = dtype + is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") - if device_mesh is not None: # In this case, the param is already on the correct device! - try: - module_to_tp: torch.nn.Module = model.get_submodule(layer) - except Exception: - raise ValueError( - "The config tp plan is wrong because the layer is not a liner layer, nor an embedding" - ) - current_module_plan = None - full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+") - if plan := re.search(full_tp_plan_, module_name): - match = re.sub("[0-9]+", "*", plan[0]) - current_module_plan = full_tp_plan[match] + for serialized_param_name, empty_param in state_dict.items(): + # serialized_param_name is the raw, serialized name + # fixed_param_name is the model's equivalent + fixed_param_name, _ = model.rename_key(serialized_param_name) - if current_module_plan is not None: - tp_layer = translate_to_torch_parallel_style(current_module_plan) - rank = tensor_device - row, col = empty_param.shape - if "rowwise" == current_module_plan: - param = param[:, rank * (col // device_mesh.size()) : (rank + 1) * (col // device_mesh.size())] - shard = Shard(1) - tp_layer.desired_input_layouts = (Shard(-1),) - elif "colwise" == current_module_plan: - param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :] - shard = Shard(0) - else: - param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :] - shard = Shard(0) - if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype: - param = param.to(param_casting_dtype) - local_parameter = DTensor.from_local( - param, - device_mesh=device_mesh, - placements=[shard] * device_mesh.ndim, - ) - if isinstance(module_to_tp.weight, nn.Parameter): - local_parameter = torch.nn.Parameter(local_parameter) - module_to_tp.weight = local_parameter - input_fn = partial( - tp_layer._prepare_input_fn, tp_layer.input_layouts, tp_layer.desired_input_layouts - ) - output_fn = partial( - tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output - ) - distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn) - else: - module_to_tp.load_state_dict({param_type: param[:]}, False, True) + if fixed_param_name not in expected_keys: + continue + # we need to use serialized_param_name as file pointer is untouched + param = ( + file_pointer.get_slice(serialized_param_name) + if shard_file.endswith(".safetensors") + else bin_state_dict[serialized_param_name] + ) + # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params + # in int/uint/bool and not cast them. + param_casting_dtype = None + is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn + + if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: + if ( + keep_in_fp32_modules is not None + and keep_in_fp32_modules.search(fixed_param_name) + and dtype == torch.float16 + ): + param_casting_dtype = torch.float32 else: - if device_map is None: - param_device = "cpu" - else: - module_name = module_name.rsplit(".", 1)[0] - device_map_regex = "|".join(device_map.keys()) - module_layer = re.search(device_map_regex, module_name) - if module_name == "" or device_map_regex is None: - raise ValueError( - f"`device_map` is used, but {module_name} doesn't have any device set. {device_map}" - ) - else: - param_device = device_map[module_layer.group()] + param_casting_dtype = dtype - if param_device == "disk" and not is_safetensors: - offload_index = offload_weight(param[:], module_name, offload_folder, offload_index) - elif param_device == "cpu" and state_dict_index is not None: - state_dict_index = offload_weight(param[:], module_name, state_dict_folder, state_dict_index) - elif ( - not is_quantized - or (not hf_quantizer.requires_parameters_quantization) - or ( - not hf_quantizer.check_quantized_param( - model, param, module_name, state_dict, param_device=param_device, device_map=device_map - ) - ) - ): - if is_fsdp_enabled(): - param_device = "cpu" if is_local_dist_rank_0() else "meta" - module = model.get_submodule(layer) - if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype: - param = param[:].to(param_casting_dtype) - module.load_state_dict( - {param_type: param[:].to(param_device)}, - False, - True, - ) + if device_mesh is not None: # In this case, the param is already on the correct device! + module_to_tp, param_type = find_submodule_and_param_name(model, fixed_param_name) + current_module_plan = None + full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+") + if plan := re.search(full_tp_plan_, fixed_param_name): + match = re.sub("[0-9]+", "*", plan[0]) + current_module_plan = full_tp_plan[match] + + if current_module_plan is not None: + tp_layer = translate_to_torch_parallel_style(current_module_plan) + rank = tensor_device + row, col = empty_param.shape + if "rowwise" == current_module_plan: + param = param[:, rank * (col // device_mesh.size()) : (rank + 1) * (col // device_mesh.size())] + shard = Shard(1) + tp_layer.desired_input_layouts = (Shard(-1),) + elif "colwise" == current_module_plan: + param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :] + shard = Shard(0) else: - hf_quantizer.create_quantized_param( - model, param[:], module_name, param_device, state_dict, unexpected_keys + param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :] + shard = Shard(0) + if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype: + param = param.to(param_casting_dtype) + local_parameter = DTensor.from_local( + param, + device_mesh=device_mesh, + placements=[shard] * device_mesh.ndim, + ) + if isinstance(module_to_tp.weight, nn.Parameter): + local_parameter = torch.nn.Parameter(local_parameter) + module_to_tp.weight = local_parameter + input_fn = partial(tp_layer._prepare_input_fn, tp_layer.input_layouts, tp_layer.desired_input_layouts) + output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output) + distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn) + else: + module_to_tp.load_state_dict({param_type: param[:]}, strict=False, assign=True) + + else: + if device_map is None: + param_device = "cpu" + else: + module_layer = re.search(device_map_regex, fixed_param_name) + if not module_layer: + raise ValueError(f"{fixed_param_name} doesn't have any device set.") + else: + param_device = device_map[module_layer.group()] + + if param_device == "disk": + if not is_safetensors: + offload_index = offload_weight(param[:], fixed_param_name, offload_folder, offload_index) + elif param_device == "cpu" and state_dict_index is not None: + state_dict_index = offload_weight(param[:], fixed_param_name, state_dict_folder, state_dict_index) + elif ( + not is_quantized + or (not hf_quantizer.requires_parameters_quantization) + or ( + not hf_quantizer.check_quantized_param( + model, + param, + fixed_param_name, + state_dict, + param_device=param_device, + device_map=device_map, ) - # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU - # and then cast it to CPU to avoid excessive memory usage on each GPU - # in comparison to the sharded model across GPUs. - if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): - module, tensor_name = get_module_from_name(model, module_name) - value = getattr(module, tensor_name) - param_to = "cpu" - if is_fsdp_enabled() and not is_local_dist_rank_0(): - param_to = "meta" - val_kwargs = {} - if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params": - val_kwargs["requires_grad"] = False - value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__) - setattr(module, tensor_name, value) + ) + ): + if is_fsdp_enabled(): + param_device = "cpu" if is_local_dist_rank_0() else "meta" + module, param_type = find_submodule_and_param_name(model, fixed_param_name) + if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype: + param = param[:].to(param_casting_dtype) + module.load_state_dict( + {param_type: param[:].to(param_device)}, + strict=False, + assign=True, + ) + else: + hf_quantizer.create_quantized_param( + model, param[:], fixed_param_name, param_device, state_dict, unexpected_keys + ) + # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU + # and then cast it to CPU to avoid excessive memory usage on each GPU + # in comparison to the sharded model across GPUs. + if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): + module, param_type = find_submodule_and_param_name(model, fixed_param_name) + value = getattr(module, param_type) + param_to = "cpu" + if is_fsdp_enabled() and not is_local_dist_rank_0(): + param_to = "meta" + val_kwargs = {} + if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params": + val_kwargs["requires_grad"] = False + value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__) + setattr(module, param_type, value) + if file_pointer is not None: + file_pointer.__exit__(None, None, None) return error_msgs, offload_index, state_dict_index @@ -4966,7 +4974,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ignore_mismatched_sizes, prefix, ) - if low_cpu_mem_usage and shard_file.endswith(".safetensors"): + if low_cpu_mem_usage: if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: for key, param in model_to_load.state_dict().items(): if param.device == torch.device("meta"): @@ -5840,18 +5848,31 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, accelerator_device_map = { param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"] } + if not len(accelerator_device_map): + return + parameter_count = defaultdict(lambda: 0) + allocation_factor = 1 + if torch.distributed.is_initialized() or len(set(accelerator_device_map.values())) >= 2: + allocation_factor = 2 + for param_name, device in accelerator_device_map.items(): try: param = model.get_parameter(param_name) except AttributeError: param = model.get_buffer(param_name) - parameter_count[device] += int(math.prod(param.shape) * 2) + parameter_count[device] += int(math.prod(param.shape) * allocation_factor) dtype = dtype if dtype is not None else torch.float32 + # This will kick off the caching allocator to avoid having to Malloc afterwards for device, param_count in parameter_count.items(): - _ = torch.empty(int(param_count), dtype=dtype, device=device, requires_grad=False) + max_memory_device = None + if device.type == "cuda": + max_memory_device = torch.cuda.mem_get_info(device.index)[0] + # allocate only if we have enough memory + if max_memory_device is None or max_memory_device > param_count * dtype_byte_size(dtype): + _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False) def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):