From 350c5d15668aa2f0467699baa795ef751b1167d3 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed, 13 Mar 2024 22:03:02 +0530 Subject: [PATCH] Add support for FSDP+QLoRA and DeepSpeed ZeRO3+QLoRA (#29587) * fsdp+qlora related changes * fixes * Update quantization_config.py * support fsdp+qlora and dsz3+qlora * Update quantization_config.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * handle fsdp+qlora and dsz3+qlora correctly while model loading * fix param count * quality * fsdp related changes * fsdp changes only when using LoRA/QLoRA * add accelerate version check * refactor, update min accelerate version and add tests 1. Update minimum accelerate version to 0.26.0 2. Clean the trainer wrt accelerate version checks 3. FSDP refactor and test for fsdp config 4. use `itemsize` instead of `dtype2bytes` dict * fix test * Address comments Co-Authored-By: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * fix the conditional flag * fix conditional flag * address comments Co-Authored-By: Zach Mueller <7831895+muellerzr@users.noreply.github.com> --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: Zach Mueller <7831895+muellerzr@users.noreply.github.com> --- src/transformers/integrations/bitsandbytes.py | 7 +++ src/transformers/modeling_utils.py | 60 ++++++++++++------- src/transformers/trainer.py | 18 ++++++ src/transformers/training_args.py | 6 +- src/transformers/utils/quantization_config.py | 13 ++++ tests/fsdp/test_fsdp.py | 39 ++++++++++++ 6 files changed, 119 insertions(+), 24 deletions(-) diff --git a/src/transformers/integrations/bitsandbytes.py b/src/transformers/integrations/bitsandbytes.py index d58e749f82..e038768b97 100644 --- a/src/transformers/integrations/bitsandbytes.py +++ 
b/src/transformers/integrations/bitsandbytes.py @@ -1,6 +1,7 @@ import importlib.metadata import warnings from copy import deepcopy +from inspect import signature from packaging import version @@ -179,6 +180,11 @@ def _replace_with_bnb_linear( ): pass else: + extra_kwargs = ( + {"quant_storage": quantization_config.bnb_4bit_quant_storage} + if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters) + else {} + ) model._modules[name] = bnb.nn.Linear4bit( in_features, out_features, @@ -186,6 +192,7 @@ def _replace_with_bnb_linear( quantization_config.bnb_4bit_compute_dtype, compress_statistics=quantization_config.bnb_4bit_use_double_quant, quant_type=quantization_config.bnb_4bit_quant_type, + **extra_kwargs, ) has_been_replaced = True # Store the module class in case we need to transpose the weight later diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 019fc1f412..2868343bfc 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -54,6 +54,7 @@ from .pytorch_utils import ( # noqa: F401 prune_linear_layer, ) from .quantizers import AutoHfQuantizer, HfQuantizer +from .quantizers.quantizers_utils import get_module_from_name from .safetensors_conversion import auto_conversion from .utils import ( ADAPTER_SAFE_WEIGHTS_NAME, @@ -496,7 +497,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True): return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys) -def load_state_dict(checkpoint_file: Union[str, os.PathLike]): +def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_quantized: bool = False): """ Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. 
""" @@ -512,8 +513,9 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): return safe_load_file(checkpoint_file) try: if ( - is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0 - ) or (is_fsdp_enabled() and not is_local_dist_rank_0()): + (is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0) + or (is_fsdp_enabled() and not is_local_dist_rank_0()) + ) and not is_quantized: map_location = "meta" else: map_location = "cpu" @@ -718,6 +720,7 @@ def _load_state_dict_into_meta_model( old_keys = [] new_keys = [] + is_quantized = hf_quantizer is not None for key in state_dict.keys(): new_key = None if "gamma" in key: @@ -797,7 +800,7 @@ def _load_state_dict_into_meta_model( elif param_device == "cpu" and state_dict_index is not None: state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index) elif ( - hf_quantizer is None + not is_quantized or (not hf_quantizer.requires_parameters_quantization) or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict)) ): @@ -805,6 +808,14 @@ def _load_state_dict_into_meta_model( set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs) else: hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys) + # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU + # and then cast it to CPU to avoid excessive memory usage on each GPU + # in comparison to the sharded model across GPUs. 
+ if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): + module, tensor_name = get_module_from_name(model, param_name) + value = getattr(module, tensor_name) + value = type(value)(value.data.to("cpu"), **value.__dict__) + setattr(module, tensor_name, value) # TODO: consider removing used param_parts from state_dict before return return error_msgs, offload_index, state_dict_index @@ -1070,7 +1081,9 @@ class ModuleUtilsMixin: # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are # used for the 4bit quantization (uint8 tensors are stored) if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit): - total_numel.append(param.numel() * 2) + total_numel.append( + param.numel() * 2 * self.hf_quantizer.quantization_config.bnb_4bit_quant_storage.itemsize + ) else: total_numel.append(param.numel()) @@ -1805,10 +1818,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix old_embeddings_requires_grad = old_embeddings.weight.requires_grad new_embeddings.requires_grad_(old_embeddings_requires_grad) self.set_input_embeddings(new_embeddings) + is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None # Update new_num_tokens with the actual size of new_embeddings if pad_to_multiple_of is not None: - if is_deepspeed_zero3_enabled(): + if is_deepspeed_zero3_enabled() and not is_quantized: import deepspeed with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None): @@ -1882,7 +1896,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if new_num_tokens is None: return old_embeddings - if is_deepspeed_zero3_enabled(): + is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None + if is_deepspeed_zero3_enabled() and not is_quantized: import deepspeed with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None): @@ -1921,7 +1936,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, 
GenerationMixin, PushToHubMix # numbers of tokens to copy n = min(old_num_tokens, new_num_tokens) - if is_deepspeed_zero3_enabled(): + if is_deepspeed_zero3_enabled() and not is_quantized: import deepspeed params = [old_embeddings.weight, new_embeddings.weight] @@ -1958,7 +1973,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if new_num_tokens is None: return old_lm_head - if is_deepspeed_zero3_enabled(): + is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None + if is_deepspeed_zero3_enabled() and not is_quantized: import deepspeed with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None): @@ -2000,7 +2016,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix num_tokens_to_copy = min(old_num_tokens, new_num_tokens) - if is_deepspeed_zero3_enabled(): + if is_deepspeed_zero3_enabled() and not is_quantized: import deepspeed params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias] @@ -3036,6 +3052,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if low_cpu_mem_usage is None: low_cpu_mem_usage = True logger.warning("`low_cpu_mem_usage` was None, now set to True since model is quantized.") + is_quantized = hf_quantizer is not None # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the # index of the files. @@ -3365,7 +3382,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Instantiate model. 
init_contexts = [no_init_weights(_enable=_fast_init)] - if is_deepspeed_zero3_enabled(): + if is_deepspeed_zero3_enabled() and not is_quantized: import deepspeed logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") @@ -3564,7 +3581,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix } if "skip_keys" in inspect.signature(dispatch_model).parameters: device_map_kwargs["skip_keys"] = model._skip_keys_device_placement - dispatch_model(model, **device_map_kwargs) + if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled(): + dispatch_model(model, **device_map_kwargs) if hf_quantizer is not None: hf_quantizer.postprocess_model(model) @@ -3610,6 +3628,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix keep_in_fp32_modules=None, ): is_safetensors = False + is_quantized = hf_quantizer is not None if device_map is not None and "disk" in device_map.values(): archive_file = ( @@ -3735,7 +3754,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if param.device == torch.device("meta"): value = torch.empty(*param.size(), dtype=target_dtype) if ( - hf_quantizer is None + not is_quantized or getattr(hf_quantizer, "requires_parameters_quantization", False) or not hf_quantizer.check_quantized_param( model, param_value=value, param_name=key, state_dict={} @@ -3765,7 +3784,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix else: not_initialized_submodules = dict(model.named_modules()) # This will only initialize submodules that are not marked as initialized by the line above. 
- if is_deepspeed_zero3_enabled(): + if is_deepspeed_zero3_enabled() and not is_quantized: import deepspeed not_initialized_parameters = list( @@ -3909,7 +3928,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload. if shard_file in disk_only_shard_files: continue - state_dict = load_state_dict(shard_file) + state_dict = load_state_dict(shard_file, is_quantized=is_quantized) # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # matching the weights in the model. @@ -3922,15 +3941,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ignore_mismatched_sizes, ) if low_cpu_mem_usage: - if is_fsdp_enabled() and not is_local_dist_rank_0(): + if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized: for key, param in model_to_load.state_dict().items(): if param.device == torch.device("meta"): - if hf_quantizer is None: - set_module_tensor_to_device( - model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype) - ) - else: - hf_quantizer.create_quantized_param(model, param, key, "cpu", state_dict) + set_module_tensor_to_device( + model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype) + ) else: new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( model_to_load, diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 8836e0be21..e73f305022 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1776,6 +1776,7 @@ class Trainer: if delay_optimizer_creation: if use_accelerator_prepare: + self._fsdp_qlora_plugin_updates() self.model = self.accelerator.prepare(self.model) self.create_optimizer_and_scheduler(num_training_steps=max_steps) @@ -4156,3 +4157,20 @@ class Trainer: ds_plugin.hf_ds_config = 
HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config) ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size) + + def _fsdp_qlora_plugin_updates(self): + if self.is_fsdp_enabled and _is_peft_model(self.model): + from peft import LoraConfig + from peft.utils.other import fsdp_auto_wrap_policy + + if isinstance(self.model.active_peft_config, LoraConfig): + fsdp_plugin = self.accelerator.state.fsdp_plugin + fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model) + if ( + getattr(self.model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point + and version.parse(accelerate_version) > version.parse("0.27.0") + ): + fsdp_plugin.set_mixed_precision( + self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True + ) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 15ddc88dcb..54cd045b20 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1721,8 +1721,10 @@ class TrainingArguments: for fsdp_option in self.fsdp: if fsdp_option.upper() in FSDP_SHARDING_STRATEGY: # set environment variable for FSDP sharding strategy - os.environ[f"{prefix}SHARDING_STRATEGY"] = str( - FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1 + os.environ[f"{prefix}SHARDING_STRATEGY"] = ( + str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1) + if is_accelerate_available("0.26.0") + else fsdp_option.upper() ) elif fsdp_option == FSDPOption.OFFLOAD: os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true" diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index a29886d8c6..3756079ca0 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -225,6 +225,8 @@ class 
BitsAndBytesConfig(QuantizationConfigMixin): bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`): This flag is used for nested quantization where the quantization constants from the first quantization are quantized again. + bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`): + This sets the storage type to pack the quantized 4-bit params. kwargs (`Dict[str, Any]`, *optional*): Additional parameters from which to initialize the configuration object. """ @@ -240,6 +242,7 @@ class BitsAndBytesConfig(QuantizationConfigMixin): bnb_4bit_compute_dtype=None, bnb_4bit_quant_type="fp4", bnb_4bit_use_double_quant=False, + bnb_4bit_quant_storage=None, **kwargs, ): self.quant_method = QuantizationMethod.BITS_AND_BYTES @@ -265,6 +268,15 @@ class BitsAndBytesConfig(QuantizationConfigMixin): else: raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype") + if bnb_4bit_quant_storage is None: + self.bnb_4bit_quant_storage = torch.uint8 + elif isinstance(bnb_4bit_quant_storage, str): + self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage) + elif isinstance(bnb_4bit_quant_storage, torch.dtype): + self.bnb_4bit_quant_storage = bnb_4bit_quant_storage + else: + raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype") + self.post_init() @property @@ -345,6 +357,7 @@ class BitsAndBytesConfig(QuantizationConfigMixin): """ output = copy.deepcopy(self.__dict__) output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1] + output["bnb_4bit_quant_storage"] = str(output["bnb_4bit_quant_storage"]).split(".")[1] output["load_in_4bit"] = self.load_in_4bit output["load_in_8bit"] = self.load_in_8bit diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py index aa5b353753..aeb232fd9e 100644 --- a/tests/fsdp/test_fsdp.py +++ b/tests/fsdp/test_fsdp.py @@ -15,6 +15,7 @@ import itertools import os import unittest +from copy import deepcopy from functools import partial 
from parameterized import parameterized @@ -171,6 +172,44 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon): self.assertEqual(v, self.fsdp_config[k]) self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true") + @parameterized.expand(params, name_func=_parameterized_custom_name_func) + def test_fsdp_config_transformers_auto_wrap(self, sharding_strategy, dtype): + output_dir = self.get_auto_remove_tmp_dir() + fsdp_config = deepcopy(self.fsdp_config) + del fsdp_config["min_num_params"] + fsdp_config["transformer_layer_cls_to_wrap"] = "BertLayer" + kwargs = { + "output_dir": output_dir, + "train_len": 128, + "save_steps": 5, + "learning_rate": 0.1, + "fsdp": f"{sharding_strategy} offload auto_wrap", + "fsdp_config": fsdp_config, + } + kwargs[dtype] = True + prefix = "FSDP_" + with mockenv_context(**self.dist_env_1_gpu): + trainer = get_regression_trainer(**kwargs) + self.assertEqual(trainer.args.fsdp[0], sharding_strategy) + self.assertEqual(trainer.args.fsdp[1], FSDPOption.OFFLOAD) + self.assertEqual(trainer.args.fsdp[2], FSDPOption.AUTO_WRAP) + fsdp_sharding_strategy = ( + str(FSDP_SHARDING_STRATEGY.index(sharding_strategy.upper()) + 1) + if is_accelerate_available("0.26.0") + else sharding_strategy.upper() + ) + self.assertEqual(os.environ[f"{prefix}SHARDING_STRATEGY"], fsdp_sharding_strategy) + self.assertEqual(os.environ[f"{prefix}OFFLOAD_PARAMS"], "true") + self.assertEqual(os.environ[f"{prefix}AUTO_WRAP_POLICY"], "TRANSFORMER_BASED_WRAP") + self.assertEqual( + os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"], ",".join(fsdp_config["transformer_layer_cls_to_wrap"]) + ) + self.assertEqual(os.environ[f"{prefix}BACKWARD_PREFETCH"], fsdp_config["backward_prefetch"].upper()) + self.assertEqual(os.environ[f"{prefix}FORWARD_PREFETCH"], fsdp_config["forward_prefetch"]) + self.assertEqual(os.environ[f"{prefix}USE_ORIG_PARAMS"], fsdp_config["use_orig_params"]) + self.assertEqual(os.environ[f"{prefix}SYNC_MODULE_STATES"], 
fsdp_config["sync_module_states"]) + self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true") + @parameterized.expand(params, name_func=_parameterized_custom_name_func) @require_torch_multi_accelerator @slow