From def581ef51f78326a3f56de9cf9c637c47b920ad Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed, 20 Dec 2023 14:25:07 +0530 Subject: [PATCH] Fix FA2 integration (#28142) * fix fa2 * fix FA2 for popular models * improve warning and add Younes as co-author Co-Authored-By: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix the warning * Add Tip * typo fix * nit --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/model_doc/llama2.md | 2 ++ src/transformers/modeling_utils.py | 6 +++--- src/transformers/models/falcon/modeling_falcon.py | 2 ++ src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 2 ++ src/transformers/models/llama/modeling_llama.py | 2 ++ src/transformers/models/mistral/modeling_mistral.py | 2 ++ src/transformers/models/mixtral/modeling_mixtral.py | 2 ++ 7 files changed, 15 insertions(+), 3 deletions(-) diff --git a/docs/source/en/model_doc/llama2.md b/docs/source/en/model_doc/llama2.md index a817a866c0..b4cd6b9ca1 100644 --- a/docs/source/en/model_doc/llama2.md +++ b/docs/source/en/model_doc/llama2.md @@ -67,6 +67,8 @@ come in several checkpoints they each contain a part of each weight of the model - The LLaMA tokenizer is a BPE model based on [sentencepiece](https://github.com/google/sentencepiece). One quirk of sentencepiece is that when decoding a sequence, if the first token is the start of the word (e.g. "Banana"), the tokenizer does not prepend the prefix space to the string. +- When using Flash Attention 2 via `attn_implementation="flash_attention_2"`, don't pass `torch_dtype` to the `from_pretrained` class method and use Automatic Mixed-Precision training. When using `Trainer`, it is simply specifying either `fp16` or `bf16` to `True`. Otherwise, make sure you are using `torch.autocast`. This is required because the Flash Attention only support `fp16` and `bf16` data type. + ## Resources diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7461501d30..1584e4a544 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1419,9 +1419,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour" ) elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]: - raise ValueError( - f"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed {torch_dtype}, this might lead to" - " unexpected behaviour." + logger.warning( + "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. " + "No dtype was provided, you should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator." ) # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 78088bc5d7..b4de486e80 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -620,6 +620,8 @@ class FalconFlashAttention2(FalconAttention): # Handle the case where the model is quantized if hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() else: target_dtype = self.query_key_value.weight.dtype diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index b6b03d07b9..0e21376204 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -378,6 +378,8 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention): # Handle the case where the model is quantized if hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() else: target_dtype = self.c_attn.weight.dtype diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8ceee2d1d4..a97755c51a 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -531,6 +531,8 @@ class LlamaFlashAttention2(LlamaAttention): # Handle the case where the model is quantized if hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() else: target_dtype = self.q_proj.weight.dtype diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index ee51bcea79..6858d5c4b8 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -431,6 +431,8 @@ class MistralFlashAttention2(MistralAttention): # Handle the case where the model is quantized if hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() else: target_dtype = self.q_proj.weight.dtype diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index a7622f149e..1c2febfbac 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -479,6 +479,8 @@ class MixtralFlashAttention2(MixtralAttention): # Handle the case where the model is quantized if hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() else: target_dtype = self.q_proj.weight.dtype