Refactor some core stuff (#36539)
* some config changes * update * current state * update * update * updates and cleanup * something that works * fixup * fixes * nits * nit * nits and fix * Update src/transformers/integrations/tensor_parallel.py Co-authored-by: Lysandre Debut <hi@lysand.re> * Update src/transformers/integrations/tensor_parallel.py Co-authored-by: Lysandre Debut <hi@lysand.re> * cleanup * style * safe import * fix * updates * rename stuff an clean * style * small updates * ups * oups * nit * protect imports * update tp * rodfl * arf * turbo nit on init * fix import error * frumble gumbgle * try to fix the import error * should fix the non model test * update keep in float32 * update * fix * nits * fix subvconfigs * test was weird * nit * fix failing test * fix instruct blip * fixes * style * x.com * fix overwrite * ok last bit of failing test --------- Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
@@ -54,17 +54,20 @@ from .integrations.deepspeed import _load_state_dict_into_zero3_model
|
||||
from .integrations.flash_attention import flash_attention_forward
|
||||
from .integrations.flex_attention import flex_attention_forward
|
||||
from .integrations.sdpa_attention import sdpa_attention_forward
|
||||
from .integrations.tensor_parallel import (
|
||||
SUPPORTED_TP_STYLES,
|
||||
shard_and_distribute_module,
|
||||
translate_to_torch_parallel_style,
|
||||
)
|
||||
from .loss.loss_utils import LOSS_MAPPING
|
||||
from .pytorch_utils import ( # noqa: F401
|
||||
Conv1D,
|
||||
apply_chunking_to_forward,
|
||||
distribute_module,
|
||||
find_pruneable_heads_and_indices,
|
||||
id_tensor_storage,
|
||||
prune_conv1d_layer,
|
||||
prune_layer,
|
||||
prune_linear_layer,
|
||||
translate_to_torch_parallel_style,
|
||||
)
|
||||
from .quantizers import AutoHfQuantizer, HfQuantizer
|
||||
from .quantizers.quantizers_utils import get_module_from_name
|
||||
@@ -151,6 +154,7 @@ logger = logging.get_logger(__name__)
|
||||
_init_weights = True
|
||||
_is_quantized = False
|
||||
_is_ds_init_called = False
|
||||
_torch_distributed_available = torch.distributed.is_available()
|
||||
|
||||
|
||||
def is_fsdp_enabled():
|
||||
@@ -181,8 +185,6 @@ else:
|
||||
if is_peft_available():
|
||||
from .utils import find_adapter_config_file
|
||||
|
||||
if is_torch_greater_or_equal("2.5"):
|
||||
from torch.distributed.tensor import DTensor, Shard
|
||||
|
||||
SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
|
||||
|
||||
@@ -756,6 +758,40 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
|
||||
setattr(submodule, param_name, new_val)
|
||||
|
||||
|
||||
def fix_tensor_type_and_device(
|
||||
model, param_name, param, dtype=None, keep_in_fp32_modules=None
|
||||
) -> Union[str, torch.dtype]:
|
||||
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
|
||||
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
|
||||
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
|
||||
|
||||
old_param = model
|
||||
if "." in param_name:
|
||||
pre, _ = param_name.rsplit(".", 1)
|
||||
|
||||
old_param = model.get_submodule(pre)
|
||||
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
|
||||
old_param = None
|
||||
|
||||
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
||||
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
|
||||
# in int/uint/bool and not cast them.
|
||||
param_casting_dtype = None
|
||||
is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
|
||||
if param.dtype.is_floating_point and not is_param_float8_e4m3fn:
|
||||
if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name):
|
||||
param_casting_dtype = torch.float32
|
||||
elif dtype is not None:
|
||||
param_casting_dtype = dtype
|
||||
elif old_param is not None:
|
||||
param_casting_dtype = old_param.dtype
|
||||
return old_param is not None and old_param.is_contiguous(), param_casting_dtype
|
||||
else:
|
||||
return False, None
|
||||
|
||||
return
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def _load_state_dict_into_meta_model(
|
||||
model: torch.nn.Module,
|
||||
@@ -787,18 +823,12 @@ def _load_state_dict_into_meta_model(
|
||||
It also initialize tensor parallelism for each module if needed.
|
||||
|
||||
"""
|
||||
tensor_device = None
|
||||
tensor_device = "cpu"
|
||||
if device_map is not None and device_map.get("", None) is not None:
|
||||
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
|
||||
if device_map is not None:
|
||||
device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))
|
||||
|
||||
# we need this later to initialize tensor parallelism
|
||||
if device_mesh is not None:
|
||||
full_tp_plan = model.config.base_model_tp_plan
|
||||
for submodule in model.modules():
|
||||
full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
|
||||
|
||||
file_pointer = None
|
||||
bin_state_dict = None
|
||||
if shard_file.endswith(".safetensors"):
|
||||
@@ -818,8 +848,6 @@ def _load_state_dict_into_meta_model(
|
||||
|
||||
is_quantized = hf_quantizer is not None
|
||||
|
||||
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
|
||||
|
||||
for serialized_param_name, empty_param in state_dict.items():
|
||||
# serialized_param_name is the raw, serialized name
|
||||
# fixed_param_name is the model's equivalent
|
||||
@@ -829,87 +857,37 @@ def _load_state_dict_into_meta_model(
|
||||
continue
|
||||
|
||||
# we need to use serialized_param_name as file pointer is untouched
|
||||
param = (
|
||||
file_pointer.get_slice(serialized_param_name)
|
||||
if shard_file.endswith(".safetensors")
|
||||
else bin_state_dict[serialized_param_name]
|
||||
if shard_file.endswith(".safetensors"):
|
||||
param = file_pointer.get_slice(serialized_param_name)
|
||||
elif shard_file.endswith(".gguf"):
|
||||
param = empty_param # For gguf the dict is actually not empty!
|
||||
else:
|
||||
param = bin_state_dict[serialized_param_name]
|
||||
|
||||
to_contiguous, param_casting_dtype = fix_tensor_type_and_device(
|
||||
model,
|
||||
param_name=fixed_param_name,
|
||||
param=empty_param,
|
||||
dtype=dtype,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
)
|
||||
|
||||
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
|
||||
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
|
||||
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
|
||||
|
||||
old_param = model
|
||||
splits = fixed_param_name.split(".")
|
||||
for split in splits:
|
||||
# We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
|
||||
old_param = getattr(old_param, split, None)
|
||||
if old_param is None:
|
||||
break
|
||||
|
||||
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
|
||||
old_param = None
|
||||
|
||||
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
|
||||
# in int/uint/bool and not cast them.
|
||||
param_casting_dtype = None
|
||||
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
|
||||
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
|
||||
if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(fixed_param_name):
|
||||
param_casting_dtype = torch.float32
|
||||
elif dtype is not None:
|
||||
param_casting_dtype = dtype
|
||||
elif old_param is not None:
|
||||
param_casting_dtype = old_param.dtype
|
||||
|
||||
if device_mesh is not None: # In this case, the param is already on the correct device!
|
||||
module_to_tp, param_type = get_module_from_name(model, fixed_param_name)
|
||||
current_module_plan = None
|
||||
full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
|
||||
if plan := re.search(full_tp_plan_, fixed_param_name):
|
||||
match = re.sub("[0-9]+", "*", plan[0])
|
||||
current_module_plan = full_tp_plan[match]
|
||||
|
||||
if current_module_plan is not None:
|
||||
tp_layer = translate_to_torch_parallel_style(current_module_plan)
|
||||
rank = tensor_device
|
||||
row, col = empty_param.shape
|
||||
if "rowwise" == current_module_plan:
|
||||
param = param[:, rank * (col // device_mesh.size()) : (rank + 1) * (col // device_mesh.size())]
|
||||
shard = Shard(1)
|
||||
tp_layer.desired_input_layouts = (Shard(-1),)
|
||||
elif "colwise" == current_module_plan:
|
||||
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
|
||||
shard = Shard(0)
|
||||
else:
|
||||
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
|
||||
shard = Shard(0)
|
||||
if param_casting_dtype is not None:
|
||||
param = param.to(param_casting_dtype)
|
||||
if old_param.is_contiguous():
|
||||
param = param.contiguous()
|
||||
local_parameter = DTensor.from_local(
|
||||
param,
|
||||
device_mesh=device_mesh,
|
||||
placements=[shard] * device_mesh.ndim,
|
||||
)
|
||||
if isinstance(module_to_tp.weight, nn.Parameter):
|
||||
local_parameter = torch.nn.Parameter(local_parameter)
|
||||
module_to_tp.weight = local_parameter
|
||||
input_fn = partial(tp_layer._prepare_input_fn, tp_layer.input_layouts, tp_layer.desired_input_layouts)
|
||||
output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output)
|
||||
distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
|
||||
else:
|
||||
param = param[:]
|
||||
if old_param is not None and old_param.is_contiguous():
|
||||
param = param.contiguous()
|
||||
module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
|
||||
|
||||
shard_and_distribute_module(
|
||||
model,
|
||||
param,
|
||||
empty_param,
|
||||
fixed_param_name,
|
||||
param_casting_dtype,
|
||||
to_contiguous,
|
||||
tensor_device, # the rank
|
||||
device_mesh,
|
||||
)
|
||||
else:
|
||||
param = param[:]
|
||||
if param_casting_dtype is not None:
|
||||
param = param.to(param_casting_dtype)
|
||||
if old_param is not None and old_param.is_contiguous():
|
||||
if to_contiguous:
|
||||
param = param.contiguous()
|
||||
|
||||
if device_map is None:
|
||||
@@ -966,6 +944,7 @@ def _load_state_dict_into_meta_model(
|
||||
val_kwargs["requires_grad"] = False
|
||||
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
|
||||
setattr(module, param_type, value)
|
||||
|
||||
if file_pointer is not None:
|
||||
file_pointer.__exit__(None, None, None)
|
||||
|
||||
@@ -1409,7 +1388,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# A tensor parallel plan to be applied to the model when TP is enabled. For
|
||||
# top-level models, this attribute is currently defined in respective model
|
||||
# code. For base models, this attribute comes from
|
||||
# `config.base_model_tp_plan` during `post_init`.
|
||||
# `config.base_model_tp_plan` during `__init__`.
|
||||
# It should identify the layers exactly: if you want to TP model.language_model.layers.fc1
|
||||
# by passing `tp_plan` to the init, it should be {"model.language_model.layers.fc1":"colwise"}
|
||||
# for example.
|
||||
_tp_plan = None
|
||||
|
||||
# A pipeline parallel plan specifying the layers which may not be present
|
||||
@@ -1475,6 +1457,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# when a different component (e.g. language_model) is used.
|
||||
self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
|
||||
|
||||
self._no_split_modules = self._no_split_modules or []
|
||||
|
||||
def post_init(self):
|
||||
"""
|
||||
A method executed at the end of each Transformer model initialization, to execute code that needs the model's
|
||||
@@ -1482,11 +1466,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
"""
|
||||
self.init_weights()
|
||||
self._backward_compatibility_gradient_checkpointing()
|
||||
|
||||
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
|
||||
if self.base_model is self:
|
||||
self._tp_plan = self.config.base_model_tp_plan
|
||||
self._pp_plan = self.config.base_model_pp_plan
|
||||
|
||||
self._tp_plan = self._tp_plan or self.config.base_model_tp_plan or {}
|
||||
for name, module in self.named_children():
|
||||
if plan := getattr(module, "_tp_plan", None):
|
||||
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})
|
||||
|
||||
if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
|
||||
for _, v in self._tp_plan.items():
|
||||
if v not in SUPPORTED_TP_STYLES:
|
||||
raise ValueError(
|
||||
f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}"
|
||||
)
|
||||
|
||||
def dequantize(self):
|
||||
"""
|
||||
Potentially dequantize the model in case it has been quantized by a quantization method that support
|
||||
@@ -4315,7 +4311,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
if device_mesh is not None and not model.supports_tp_plan:
|
||||
raise NotImplementedError("This model does not have a tensor parallel plan.")
|
||||
if config.base_model_tp_plan is None and config.get_text_config().base_model_tp_plan is None:
|
||||
raise NotImplementedError("This model does not have a tensor parallel plan.")
|
||||
|
||||
# make sure we use the model's config since the __init__ call might have copied it
|
||||
config = model.config
|
||||
@@ -4453,7 +4450,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
model,
|
||||
state_dict,
|
||||
loaded_state_dict_keys, # XXX: rename?
|
||||
resolved_archive_file,
|
||||
resolved_archive_file or gguf_file,
|
||||
pretrained_model_name_or_path,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
sharded_metadata=sharded_metadata,
|
||||
@@ -4565,7 +4562,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
@staticmethod
|
||||
def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
|
||||
"""Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
|
||||
|
||||
# Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
|
||||
# This rename is logged.
|
||||
if key.endswith("LayerNorm.beta"):
|
||||
@@ -4590,6 +4586,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
return key, False
|
||||
|
||||
def rename_key(self, key):
|
||||
"""
|
||||
When we load a LlamaModel from a checkpoint made using LlamaForCausalLM, the keys have an extra
|
||||
prefix, which can be accessed in the `LlamaModel` via the `self.base_model_prefix` attribute.
|
||||
|
||||
But, what if there is an extra layer on top of it? You load a MistralModel from a LlavaForConditionalGeneration?
|
||||
In that what you actually want is to cut whatever is left of the key.
|
||||
"""
|
||||
new_key = key
|
||||
if len(self.base_model_prefix) > 0:
|
||||
if not hasattr(self, self.base_model_prefix) and key.startswith(self.base_model_prefix):
|
||||
@@ -4940,7 +4943,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
unexpected_keys=unexpected_keys,
|
||||
device_mesh=device_mesh,
|
||||
resolved_archive_file=resolved_archive_file,
|
||||
shard_file=resolved_archive_file,
|
||||
weights_only=weights_only,
|
||||
)
|
||||
else:
|
||||
@@ -5019,7 +5022,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
|
||||
model_to_load,
|
||||
state_dict,
|
||||
start_prefix,
|
||||
prefix,
|
||||
expected_keys,
|
||||
device_map=device_map,
|
||||
offload_folder=offload_folder,
|
||||
@@ -5898,10 +5901,21 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
|
||||
|
||||
for param_name, device in accelerator_device_map.items():
|
||||
try:
|
||||
param = model.get_parameter(param_name)
|
||||
param = getattr(model, param_name)
|
||||
except AttributeError:
|
||||
param = model.get_buffer(param_name)
|
||||
parameter_count[device] += int(math.prod(param.shape) * allocation_factor)
|
||||
if "." in param_name:
|
||||
param_name, param_type = param_name.rsplit(".", 1)
|
||||
param = getattr(model.get_submodule(param_name), param_type)
|
||||
else:
|
||||
param = model.get_buffer(param_name)
|
||||
|
||||
param_size = int(math.prod(param.shape) * allocation_factor)
|
||||
|
||||
if _torch_distributed_available and torch.distributed.is_initialized():
|
||||
generic_name = re.sub(r"\d+", "*", param_name)
|
||||
param_size //= torch.distributed.get_world_size() if not model._tp_plan.get(generic_name, False) else 1
|
||||
|
||||
parameter_count[device] += param_size
|
||||
|
||||
dtype = dtype if dtype is not None else torch.float32
|
||||
|
||||
|
||||
Reference in New Issue
Block a user