* fix * style * better allocation * fix * fix * style * revert disk * exit * style * return if nothing to cache * dtensor guard * fix regression * fix regression * fix * fix
This commit is contained in:
@@ -41,7 +41,6 @@ import torch.distributed.tensor
|
|||||||
from huggingface_hub import split_torch_state_dict_into_shards
|
from huggingface_hub import split_torch_state_dict_into_shards
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torch.distributed.tensor import DTensor, Shard
|
|
||||||
from torch.distributions import constraints
|
from torch.distributions import constraints
|
||||||
from torch.nn import CrossEntropyLoss, Identity
|
from torch.nn import CrossEntropyLoss, Identity
|
||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
@@ -67,7 +66,6 @@ from .pytorch_utils import ( # noqa: F401
|
|||||||
translate_to_torch_parallel_style,
|
translate_to_torch_parallel_style,
|
||||||
)
|
)
|
||||||
from .quantizers import AutoHfQuantizer, HfQuantizer
|
from .quantizers import AutoHfQuantizer, HfQuantizer
|
||||||
from .quantizers.quantizers_utils import get_module_from_name
|
|
||||||
from .safetensors_conversion import auto_conversion
|
from .safetensors_conversion import auto_conversion
|
||||||
from .utils import (
|
from .utils import (
|
||||||
ACCELERATE_MIN_VERSION,
|
ACCELERATE_MIN_VERSION,
|
||||||
@@ -181,6 +179,9 @@ else:
|
|||||||
if is_peft_available():
|
if is_peft_available():
|
||||||
from .utils import find_adapter_config_file
|
from .utils import find_adapter_config_file
|
||||||
|
|
||||||
|
if is_torch_greater_or_equal("2.5"):
|
||||||
|
from torch.distributed.tensor import DTensor, Shard
|
||||||
|
|
||||||
SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
|
SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
|
||||||
|
|
||||||
TORCH_INIT_FUNCTIONS = {
|
TORCH_INIT_FUNCTIONS = {
|
||||||
@@ -702,7 +703,7 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
|
|||||||
return shared_tensors, identical
|
return shared_tensors, identical
|
||||||
|
|
||||||
|
|
||||||
def find_submodule_and_param_name(model, long_key, start_prefix):
|
def find_submodule_and_param_name(model, long_key, start_prefix=""):
|
||||||
"""
|
"""
|
||||||
A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed
|
A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed
|
||||||
from the start of the key
|
from the start of the key
|
||||||
@@ -767,7 +768,6 @@ def _load_state_dict_into_meta_model(
|
|||||||
is_safetensors=False,
|
is_safetensors=False,
|
||||||
keep_in_fp32_modules=None,
|
keep_in_fp32_modules=None,
|
||||||
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
|
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
|
||||||
pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys
|
|
||||||
device_mesh=None,
|
device_mesh=None,
|
||||||
shard_file=None,
|
shard_file=None,
|
||||||
):
|
):
|
||||||
@@ -786,12 +786,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
if device_map is not None and device_map.get("", None) is not None:
|
if device_map is not None and device_map.get("", None) is not None:
|
||||||
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
|
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
|
||||||
|
|
||||||
with safe_open(shard_file, framework="pt", device=tensor_device) as file_pointer:
|
device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))
|
||||||
error_msgs = []
|
|
||||||
|
|
||||||
is_quantized = hf_quantizer is not None
|
|
||||||
|
|
||||||
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
|
||||||
|
|
||||||
# we need this later to initialize tensor parallelism
|
# we need this later to initialize tensor parallelism
|
||||||
if device_mesh is not None:
|
if device_mesh is not None:
|
||||||
@@ -799,24 +794,42 @@ def _load_state_dict_into_meta_model(
|
|||||||
for submodule in model.modules():
|
for submodule in model.modules():
|
||||||
full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
|
full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
|
||||||
|
|
||||||
for serialized_param_name, empty_param in state_dict.items():
|
file_pointer = None
|
||||||
# param_name is the raw, serialized name
|
bin_state_dict = None
|
||||||
# new_param_name is the model's equivalent
|
if shard_file.endswith(".safetensors"):
|
||||||
module_name, _ = model.rename_key(serialized_param_name)
|
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
||||||
if module_name not in expected_keys:
|
else:
|
||||||
continue
|
bin_state_dict = load_state_dict(shard_file, map_location="cpu")
|
||||||
layer, param_type = module_name.rsplit(".", 1)
|
|
||||||
|
|
||||||
# param name needs to stay untouched as it's in the file
|
error_msgs = []
|
||||||
param = file_pointer.get_slice(serialized_param_name)
|
|
||||||
|
is_quantized = hf_quantizer is not None
|
||||||
|
|
||||||
|
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
||||||
|
|
||||||
|
for serialized_param_name, empty_param in state_dict.items():
|
||||||
|
# serialized_param_name is the raw, serialized name
|
||||||
|
# fixed_param_name is the model's equivalent
|
||||||
|
fixed_param_name, _ = model.rename_key(serialized_param_name)
|
||||||
|
|
||||||
|
if fixed_param_name not in expected_keys:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# we need to use serialized_param_name as file pointer is untouched
|
||||||
|
param = (
|
||||||
|
file_pointer.get_slice(serialized_param_name)
|
||||||
|
if shard_file.endswith(".safetensors")
|
||||||
|
else bin_state_dict[serialized_param_name]
|
||||||
|
)
|
||||||
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
|
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
|
||||||
# in int/uint/bool and not cast them.
|
# in int/uint/bool and not cast them.
|
||||||
param_casting_dtype = None
|
param_casting_dtype = None
|
||||||
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
|
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
|
||||||
|
|
||||||
if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
|
if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
|
||||||
if (
|
if (
|
||||||
keep_in_fp32_modules is not None
|
keep_in_fp32_modules is not None
|
||||||
and keep_in_fp32_modules.search(module_name)
|
and keep_in_fp32_modules.search(fixed_param_name)
|
||||||
and dtype == torch.float16
|
and dtype == torch.float16
|
||||||
):
|
):
|
||||||
param_casting_dtype = torch.float32
|
param_casting_dtype = torch.float32
|
||||||
@@ -824,15 +837,10 @@ def _load_state_dict_into_meta_model(
|
|||||||
param_casting_dtype = dtype
|
param_casting_dtype = dtype
|
||||||
|
|
||||||
if device_mesh is not None: # In this case, the param is already on the correct device!
|
if device_mesh is not None: # In this case, the param is already on the correct device!
|
||||||
try:
|
module_to_tp, param_type = find_submodule_and_param_name(model, fixed_param_name)
|
||||||
module_to_tp: torch.nn.Module = model.get_submodule(layer)
|
|
||||||
except Exception:
|
|
||||||
raise ValueError(
|
|
||||||
"The config tp plan is wrong because the layer is not a liner layer, nor an embedding"
|
|
||||||
)
|
|
||||||
current_module_plan = None
|
current_module_plan = None
|
||||||
full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
|
full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
|
||||||
if plan := re.search(full_tp_plan_, module_name):
|
if plan := re.search(full_tp_plan_, fixed_param_name):
|
||||||
match = re.sub("[0-9]+", "*", plan[0])
|
match = re.sub("[0-9]+", "*", plan[0])
|
||||||
current_module_plan = full_tp_plan[match]
|
current_module_plan = full_tp_plan[match]
|
||||||
|
|
||||||
@@ -860,63 +868,61 @@ def _load_state_dict_into_meta_model(
|
|||||||
if isinstance(module_to_tp.weight, nn.Parameter):
|
if isinstance(module_to_tp.weight, nn.Parameter):
|
||||||
local_parameter = torch.nn.Parameter(local_parameter)
|
local_parameter = torch.nn.Parameter(local_parameter)
|
||||||
module_to_tp.weight = local_parameter
|
module_to_tp.weight = local_parameter
|
||||||
input_fn = partial(
|
input_fn = partial(tp_layer._prepare_input_fn, tp_layer.input_layouts, tp_layer.desired_input_layouts)
|
||||||
tp_layer._prepare_input_fn, tp_layer.input_layouts, tp_layer.desired_input_layouts
|
output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output)
|
||||||
)
|
|
||||||
output_fn = partial(
|
|
||||||
tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output
|
|
||||||
)
|
|
||||||
distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
|
distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
|
||||||
else:
|
else:
|
||||||
module_to_tp.load_state_dict({param_type: param[:]}, False, True)
|
module_to_tp.load_state_dict({param_type: param[:]}, strict=False, assign=True)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if device_map is None:
|
if device_map is None:
|
||||||
param_device = "cpu"
|
param_device = "cpu"
|
||||||
else:
|
else:
|
||||||
module_name = module_name.rsplit(".", 1)[0]
|
module_layer = re.search(device_map_regex, fixed_param_name)
|
||||||
device_map_regex = "|".join(device_map.keys())
|
if not module_layer:
|
||||||
module_layer = re.search(device_map_regex, module_name)
|
raise ValueError(f"{fixed_param_name} doesn't have any device set.")
|
||||||
if module_name == "" or device_map_regex is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"`device_map` is used, but {module_name} doesn't have any device set. {device_map}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
param_device = device_map[module_layer.group()]
|
param_device = device_map[module_layer.group()]
|
||||||
|
|
||||||
if param_device == "disk" and not is_safetensors:
|
if param_device == "disk":
|
||||||
offload_index = offload_weight(param[:], module_name, offload_folder, offload_index)
|
if not is_safetensors:
|
||||||
|
offload_index = offload_weight(param[:], fixed_param_name, offload_folder, offload_index)
|
||||||
elif param_device == "cpu" and state_dict_index is not None:
|
elif param_device == "cpu" and state_dict_index is not None:
|
||||||
state_dict_index = offload_weight(param[:], module_name, state_dict_folder, state_dict_index)
|
state_dict_index = offload_weight(param[:], fixed_param_name, state_dict_folder, state_dict_index)
|
||||||
elif (
|
elif (
|
||||||
not is_quantized
|
not is_quantized
|
||||||
or (not hf_quantizer.requires_parameters_quantization)
|
or (not hf_quantizer.requires_parameters_quantization)
|
||||||
or (
|
or (
|
||||||
not hf_quantizer.check_quantized_param(
|
not hf_quantizer.check_quantized_param(
|
||||||
model, param, module_name, state_dict, param_device=param_device, device_map=device_map
|
model,
|
||||||
|
param,
|
||||||
|
fixed_param_name,
|
||||||
|
state_dict,
|
||||||
|
param_device=param_device,
|
||||||
|
device_map=device_map,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
if is_fsdp_enabled():
|
if is_fsdp_enabled():
|
||||||
param_device = "cpu" if is_local_dist_rank_0() else "meta"
|
param_device = "cpu" if is_local_dist_rank_0() else "meta"
|
||||||
module = model.get_submodule(layer)
|
module, param_type = find_submodule_and_param_name(model, fixed_param_name)
|
||||||
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
|
if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
|
||||||
param = param[:].to(param_casting_dtype)
|
param = param[:].to(param_casting_dtype)
|
||||||
module.load_state_dict(
|
module.load_state_dict(
|
||||||
{param_type: param[:].to(param_device)},
|
{param_type: param[:].to(param_device)},
|
||||||
False,
|
strict=False,
|
||||||
True,
|
assign=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hf_quantizer.create_quantized_param(
|
hf_quantizer.create_quantized_param(
|
||||||
model, param[:], module_name, param_device, state_dict, unexpected_keys
|
model, param[:], fixed_param_name, param_device, state_dict, unexpected_keys
|
||||||
)
|
)
|
||||||
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
|
# For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
|
||||||
# and then cast it to CPU to avoid excessive memory usage on each GPU
|
# and then cast it to CPU to avoid excessive memory usage on each GPU
|
||||||
# in comparison to the sharded model across GPUs.
|
# in comparison to the sharded model across GPUs.
|
||||||
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
|
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
|
||||||
module, tensor_name = get_module_from_name(model, module_name)
|
module, param_type = find_submodule_and_param_name(model, fixed_param_name)
|
||||||
value = getattr(module, tensor_name)
|
value = getattr(module, param_type)
|
||||||
param_to = "cpu"
|
param_to = "cpu"
|
||||||
if is_fsdp_enabled() and not is_local_dist_rank_0():
|
if is_fsdp_enabled() and not is_local_dist_rank_0():
|
||||||
param_to = "meta"
|
param_to = "meta"
|
||||||
@@ -924,7 +930,9 @@ def _load_state_dict_into_meta_model(
|
|||||||
if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
|
if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
|
||||||
val_kwargs["requires_grad"] = False
|
val_kwargs["requires_grad"] = False
|
||||||
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
|
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
|
||||||
setattr(module, tensor_name, value)
|
setattr(module, param_type, value)
|
||||||
|
if file_pointer is not None:
|
||||||
|
file_pointer.__exit__(None, None, None)
|
||||||
|
|
||||||
return error_msgs, offload_index, state_dict_index
|
return error_msgs, offload_index, state_dict_index
|
||||||
|
|
||||||
@@ -4966,7 +4974,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
ignore_mismatched_sizes,
|
ignore_mismatched_sizes,
|
||||||
prefix,
|
prefix,
|
||||||
)
|
)
|
||||||
if low_cpu_mem_usage and shard_file.endswith(".safetensors"):
|
if low_cpu_mem_usage:
|
||||||
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
|
if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
|
||||||
for key, param in model_to_load.state_dict().items():
|
for key, param in model_to_load.state_dict().items():
|
||||||
if param.device == torch.device("meta"):
|
if param.device == torch.device("meta"):
|
||||||
@@ -5840,18 +5848,31 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
|
|||||||
accelerator_device_map = {
|
accelerator_device_map = {
|
||||||
param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"]
|
param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"]
|
||||||
}
|
}
|
||||||
|
if not len(accelerator_device_map):
|
||||||
|
return
|
||||||
|
|
||||||
parameter_count = defaultdict(lambda: 0)
|
parameter_count = defaultdict(lambda: 0)
|
||||||
|
allocation_factor = 1
|
||||||
|
if torch.distributed.is_initialized() or len(set(accelerator_device_map.values())) >= 2:
|
||||||
|
allocation_factor = 2
|
||||||
|
|
||||||
for param_name, device in accelerator_device_map.items():
|
for param_name, device in accelerator_device_map.items():
|
||||||
try:
|
try:
|
||||||
param = model.get_parameter(param_name)
|
param = model.get_parameter(param_name)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
param = model.get_buffer(param_name)
|
param = model.get_buffer(param_name)
|
||||||
parameter_count[device] += int(math.prod(param.shape) * 2)
|
parameter_count[device] += int(math.prod(param.shape) * allocation_factor)
|
||||||
|
|
||||||
dtype = dtype if dtype is not None else torch.float32
|
dtype = dtype if dtype is not None else torch.float32
|
||||||
|
|
||||||
# This will kick off the caching allocator to avoid having to Malloc afterwards
|
# This will kick off the caching allocator to avoid having to Malloc afterwards
|
||||||
for device, param_count in parameter_count.items():
|
for device, param_count in parameter_count.items():
|
||||||
_ = torch.empty(int(param_count), dtype=dtype, device=device, requires_grad=False)
|
max_memory_device = None
|
||||||
|
if device.type == "cuda":
|
||||||
|
max_memory_device = torch.cuda.mem_get_info(device.index)[0]
|
||||||
|
# allocate only if we have enough memory
|
||||||
|
if max_memory_device is None or max_memory_device > param_count * dtype_byte_size(dtype):
|
||||||
|
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
|
||||||
|
|
||||||
|
|
||||||
def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
|
def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
|
||||||
|
|||||||
Reference in New Issue
Block a user