diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5bc952e7bd..2918d3a44e 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -787,6 +787,7 @@ def _load_state_dict_into_meta_model( keep_in_fp32_modules=None, unexpected_keys=None, # passing `unexpected` for cleanup from quantization items pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys + device_mesh=None, ): """ This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its @@ -796,6 +797,8 @@ def _load_state_dict_into_meta_model( `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in `bert.pooler.dense.weight` + It also initializes tensor parallelism for each module if needed. + """ # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model @@ -809,6 +812,12 @@ def _load_state_dict_into_meta_model( is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") + # we need this later to initialize tensor parallelism + if device_mesh is not None: + full_tp_plan = model.config.base_model_tp_plan + for submodule in model.modules(): + full_tp_plan.update(getattr(submodule, "_tp_plan", {})) + for param_name, param in state_dict.items(): if param_name not in expected_keys: continue @@ -912,6 +921,37 @@ def _load_state_dict_into_meta_model( setattr(module, tensor_name, value) # TODO: consider removing used param_parts from state_dict before return + # In this case, let's parallelize the modules! 
+ if device_mesh is not None: + # Immediate parent + split_parent_module_name = param_name.split(".")[:-1] + parent_module_name = ".".join(split_parent_module_name) + parent_module = model + for name in split_parent_module_name: + parent_module = getattr(parent_module, name) + + # Check if we are part of the tp_plan + current_module_plan = None + for param, plan in full_tp_plan.items(): + # "*" are a placeholder for layer indices, so we replace them by "[0-9]+" in the regex pattern + pattern = param.replace("*", "[0-9]+") + if re.search(pattern, parent_module_name): + current_module_plan = plan + break + + # We can only apply the tp_plan after all parameters of the current module have been correctly initialized (e.g. + # if we have bias, we need both `weights` and `bias` of a nn.Linear to be initialized) + process_device = list(device_map.values())[0] + all_module_parameters_initialized = all( + m.device == process_device for m in parent_module.parameters(recurse=False) + ) and all(m.device == process_device for m in parent_module.buffers(recurse=False)) + if current_module_plan is not None and all_module_parameters_initialized: + torch.distributed.tensor.parallel.parallelize_module( + parent_module, + device_mesh=device_mesh, + parallelize_plan=translate_to_torch_parallel_style(current_module_plan), + ) + return error_msgs, offload_index, state_dict_index @@ -3489,12 +3529,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) # We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple - # `device_map` pointing to the correct device. If we don't, torch will use the default device (index 0) for all - # childs processes at parallelization time, resulting in excessive memory usage on device 0 and OOMs. 
- # And temporarily setting the default device to current process rank result in the following error - # `torch.distributed.DistBackendError: Attempt to perform collective on tensor not on device passed to init_process_group` - tp_device = None + # `device_map` pointing to the correct device + device_mesh = None if tp_plan is not None: + if not is_torch_greater_or_equal("2.5"): + raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.") if not torch.distributed.is_initialized(): raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.") @@ -3506,6 +3545,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # This is the easiest way to dispatch to the current process device device_map = tp_device + # Assuming sharding the model onto the world + world_size = torch.distributed.get_world_size() + device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,)) + if is_fsdp_enabled(): low_cpu_mem_usage = True @@ -3600,7 +3643,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if low_cpu_mem_usage is None: low_cpu_mem_usage = True elif not low_cpu_mem_usage: - raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`") + raise ValueError("Passing along a `device_map` or a `tp_plan` requires `low_cpu_mem_usage=True`") if low_cpu_mem_usage: if is_deepspeed_zero3_enabled(): @@ -3609,7 +3652,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) elif not is_accelerate_available(): raise ImportError( - f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" + f"Using `low_cpu_mem_usage=True`, a `device_map` or a `tp_plan` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" ) # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation. 
@@ -4186,6 +4229,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Let's make sure we don't run the init function of buffer modules model = cls(config, *model_args, **model_kwargs) + if device_mesh is not None and not model.supports_tp_plan: + raise NotImplementedError("This model does not have a tensor parallel plan.") + # make sure we use the model's config since the __init__ call might have copied it config = model.config @@ -4336,6 +4382,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix keep_in_fp32_modules=keep_in_fp32_modules, gguf_path=gguf_path, weights_only=weights_only, + device_mesh=device_mesh, ) # make sure token embedding weights are still tied if needed @@ -4370,8 +4417,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) pass - # Dispatch model with hooks on all devices if necessary - if device_map is not None: + # Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly + # harms performance) + if device_map is not None and device_mesh is None: device_map_kwargs = { "device_map": device_map, "offload_dir": offload_folder, @@ -4398,6 +4446,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): dispatch_model(model, **device_map_kwargs) + # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is + # not part of the state_dict (persistent=False) + if device_mesh is not None: + for buffer in model.buffers(): + if buffer.device != tp_device: + buffer.data = buffer.to(tp_device) + if hf_quantizer is not None: hf_quantizer.postprocess_model(model, config=config) model.hf_quantizer = hf_quantizer @@ -4420,16 +4475,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix } return model, loading_info - if tp_plan is not None: - assert 
tp_device is not None, "tp_device not set!" - if not model.supports_tp_plan: - raise NotImplementedError("This model does not have a tensor parallel plan.") - # Assuming sharding the model onto the world - world_size = torch.distributed.get_world_size() - device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,)) - # Apply Tensor Parallelism - model.tensor_parallel(device_mesh) - return model @staticmethod @@ -4523,6 +4568,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix keep_in_fp32_modules=None, gguf_path=None, weights_only=True, + device_mesh=None, ): is_safetensors = False is_quantized = hf_quantizer is not None @@ -4822,6 +4868,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix is_safetensors=is_safetensors, keep_in_fp32_modules=keep_in_fp32_modules, unexpected_keys=unexpected_keys, + device_mesh=device_mesh, ) else: # Sharded checkpoint or whole but low_cpu_mem_usage==True @@ -4911,6 +4958,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix is_safetensors=is_safetensors, keep_in_fp32_modules=keep_in_fp32_modules, unexpected_keys=unexpected_keys, + device_mesh=device_mesh, ) error_msgs += new_error_msgs else: @@ -5188,7 +5236,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix def tensor_parallel(self, device_mesh): """ - Tensor parallelize the model across the given device mesh. + Tensor parallelize the model across the given device mesh. This function is a helper to be called after the model + was already loaded in memory, note however that this means that each process will first initialize the whole model, + then parallelize it across devices. Thus there is a huge waste of GPU memory, and this can lead to OOM at loading time. 
+ + Calling `from_pretrained(..., tp_plan="auto")` is preferred, and will parallelize module-by-module during initialization, + so that the expected per-device memory spike at loading time is not larger than the final model size on each device. Args: device_mesh (`torch.distributed.DeviceMesh`): diff --git a/tests/tp/test_tp.py b/tests/tp/test_tp.py index 7b9bff5f16..6a564e5524 100644 --- a/tests/tp/test_tp.py +++ b/tests/tp/test_tp.py @@ -81,17 +81,13 @@ class TestTensorParallel(TestCasePlus): model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto") torch.distributed.barrier() - # The expected full model memory footprint - expected_model_memory = 16 + # The expected model memory footprint. We add 1 as not all the modules are split (e.g. the embeddings) + expected_model_memory_per_device = (16 / world_size) + 1 overhead_factor = 1.2 - # Assert we did not use more than the full model expected memory (with some overhead) - if not torch.cuda.max_memory_allocated(device) / 1024**3 < expected_model_memory * overhead_factor: - raise ValueError("Loading the model used more than the full model size") - - # Assert we correctly handled the sharding between devices - if not torch.cuda.memory_allocated(device) / 1024**3 < (expected_model_memory / world_size) * overhead_factor: - raise ValueError("Each model shard is larger than what is expected.") + # Check that we do not use more than the expected sharded size during initialization + if torch.cuda.max_memory_allocated(device) / 1024**3 > expected_model_memory_per_device * overhead_factor: + raise ValueError("Loading the model used more than the expected fraction of model size per device") torch.distributed.barrier() torch.distributed.destroy_process_group()