From 1da1302ec8832e58510801cbbfd506194d8ce7ea Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 4 Dec 2023 13:52:17 +0100 Subject: [PATCH] Flash Attention 2 support for RoCm (#27611) * support FA2 * fix typo * fix broken tests * fix more test errors * left/right * fix bug * more test * typo * fix layout flash attention falcon * do not support this case * use allclose instead of equal * fix various bugs with flash attention * bump * fix test * fix mistral * use skiptest instead of return that may be misleading * add fix causal arg flash attention * fix copies * more explicit comment * still use self.is_causal * fix causal argument * comment * fixes * update documentation * add link * wrong test * simplify FA2 RoCm requirements * update opt * make flash_attn_uses_top_left_mask attribute private and precise comment * better error handling * fix copy & mistral * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/utils/import_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * use is_flash_attn_greater_or_equal_2_10 instead of is_flash_attn_greater_or_equal_210 * fix merge * simplify * inline args --------- Co-authored-by: Felix Marty Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- docs/source/en/model_doc/gpt_neo.md | 8 ++--- docs/source/en/perf_infer_gpu_one.md | 8 ++--- src/transformers/modeling_utils.py | 34 +++++++++++++------ src/transformers/models/bark/modeling_bark.py | 20 +++++++++-- src/transformers/models/bart/modeling_bart.py | 20 +++++++++-- .../models/distilbert/modeling_distilbert.py | 20 +++++++++-- 
.../models/falcon/modeling_falcon.py | 20 +++++++++-- .../gpt_bigcode/modeling_gpt_bigcode.py | 20 +++++++++-- .../models/gpt_neo/modeling_gpt_neo.py | 20 +++++++++-- .../models/llama/modeling_llama.py | 21 ++++++++++-- .../models/mbart/modeling_mbart.py | 20 +++++++++-- .../models/mistral/modeling_mistral.py | 24 ++++++++++--- src/transformers/models/opt/modeling_opt.py | 20 +++++++++-- .../models/whisper/modeling_whisper.py | 20 +++++++++-- src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 24 ++++++++++--- tests/test_modeling_common.py | 4 +-- 17 files changed, 253 insertions(+), 51 deletions(-) diff --git a/docs/source/en/model_doc/gpt_neo.md b/docs/source/en/model_doc/gpt_neo.md index fb2385bc73..96b6a8c96f 100644 --- a/docs/source/en/model_doc/gpt_neo.md +++ b/docs/source/en/model_doc/gpt_neo.md @@ -56,13 +56,9 @@ The `generate()` method can be used to generate text using GPT Neo model. ## Combining GPT-Neo and Flash Attention 2 -First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature. +First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature, and make sure your hardware is compatible with Flash-Attention 2. More details about the installation are available [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2). -```bash -pip install -U flash-attn --no-build-isolation -``` - -Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16``) +Also make sure to load your model in half-precision (e.g. `torch.float16`). 
To load and run a model using Flash Attention 2, refer to the snippet below: diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 82ec39441f..d91ed2094f 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -38,11 +38,9 @@ FlashAttention-2 is experimental and may change considerably in future versions. FlashAttention-2 supports inference with Llama, Mistral, Falcon and Bark models. You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request. -Before you begin, make sure you have FlashAttention-2 installed (see the [installation](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) guide for more details about prerequisites): +Before you begin, make sure you have FlashAttention-2 installed. For NVIDIA GPUs, the library can be installed with pip: `pip install flash-attn --no-build-isolation`. We strongly recommend referring to the [detailed installation instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features). -```bash -pip install flash-attn --no-build-isolation -``` +FlashAttention-2 is also supported on AMD GPUs, with the current support limited to **Instinct MI210 and Instinct MI250**. We strongly recommend using the following [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to run FlashAttention-2 on AMD GPUs. To enable FlashAttention-2, add the `use_flash_attention_2` parameter to [`~AutoModelForCausalLM.from_pretrained`]: @@ -62,7 +60,7 @@ model = AutoModelForCausalLM.from_pretrained( -FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`, and it only runs on Nvidia GPUs. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2. 
+FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-2. diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 1dba62efe4..4d1178bc68 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1281,17 +1281,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) if not is_flash_attn_2_available(): - raise ImportError( - "Flash Attention 2 is not available. Please refer to the documentation of https://github.com/Dao-AILab/flash-attention for" - " installing it. Make sure to have at least the version 2.1.0" - ) - else: flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) - is_flash_greater_than_2 = flash_attention_version >= version.parse("2.1.0") - if not is_flash_greater_than_2: - raise ValueError( - f"You need flash_attn package version to be greater or equal than 2.1. Make sure to have that version installed - detected version {flash_attention_version}" - ) + + preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:" + install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2." + if torch.version.cuda: + if importlib.util.find_spec("flash_attn") is None: + raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}") + + flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) + if flash_attention_version < version.parse("2.1.0"): + raise ImportError( + f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}" + ) + else: + raise ImportError(f"{preface} Flash Attention 2 is not available. 
{install_message}") + elif torch.version.hip: + if importlib.util.find_spec("flash_attn") is None: + raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}") + flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) + if flash_attention_version < version.parse("2.0.4"): + raise ImportError( + f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}" + ) + else: + raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}") _is_bettertransformer = getattr(cls, "use_bettertransformer", False) diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index f8b9eab5d3..d472b85474 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -34,6 +34,7 @@ from ...utils import ( add_start_docstrings_to_model_forward, is_accelerate_available, is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, ) from ..auto import AutoModel @@ -214,6 +215,15 @@ class BarkSelfFlashAttention2(BarkSelfAttention): flash attention and deal with padding tokens in case the input contains any of them. """ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + def _split_heads(self, tensor, num_heads, attn_head_size): """ Splits hidden_size dim into attn_head_size and num_heads @@ -301,6 +311,12 @@ class BarkSelfFlashAttention2(BarkSelfAttention): softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] @@ -321,13 +337,13 @@ class BarkSelfFlashAttention2(BarkSelfAttention): max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=self.is_causal, + causal=causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal ) return attn_output diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index efca985f67..a71f79b301 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -42,6 +42,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -294,6 +295,15 @@ class BartFlashAttention2(BartAttention): flash attention and deal with padding tokens in case the input contains any of them. 
""" + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -418,6 +428,12 @@ class BartFlashAttention2(BartAttention): softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] @@ -438,13 +454,13 @@ class BartFlashAttention2(BartAttention): max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=self.is_causal, + causal=causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal ) return attn_output diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 144fde42e0..2e58d1728e 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -46,6 +46,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -269,6 +270,15 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention): API of flash attention and deal with padding tokens in case the input contains any of them. """ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + def forward( self, query: torch.Tensor, @@ -363,6 +373,12 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention): softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] @@ -383,13 +399,13 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention): max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=self.is_causal, + causal=causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal ) return attn_output diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 3e46903b9c..01d9f0c3ed 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -38,6 +38,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, ) from .configuration_falcon import FalconConfig @@ -516,6 +517,15 @@ class FalconFlashAttention2(FalconAttention): flash attention and deal with 
padding tokens in case the input contains any of them. """ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + def forward( self, hidden_states: torch.Tensor, @@ -631,6 +641,12 @@ class FalconFlashAttention2(FalconAttention): softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] @@ -651,13 +667,13 @@ class FalconFlashAttention2(FalconAttention): max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=self.is_causal, + causal=causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal ) return attn_output diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 2d12d59097..45c96146a1 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -34,6 +34,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, ) from .configuration_gpt_bigcode import GPTBigCodeConfig @@ -292,6 +293,15 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention): API of flash attention and deal with padding tokens in case the input contains any of them. """ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + def forward( self, hidden_states: torch.Tensor, @@ -422,6 +432,12 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention): softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] @@ -442,13 +458,13 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention): max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=self.is_causal, + causal=causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal ) return attn_output diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 7cc3bef70f..1089322cf9 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -42,6 +42,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, is_torch_fx_available, logging, ) @@ -299,6 +300,15 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention): flash attention and deal with padding tokens 
in case the input contains any of them. """ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + def forward( self, hidden_states, @@ -400,6 +410,12 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention): softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] @@ -420,13 +436,13 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention): max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=self.is_causal, + causal=causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal ) return attn_output diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e53f89276f..583fc8064e 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -37,6 +37,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -442,6 +443,14 @@ class LlamaFlashAttention2(LlamaAttention): flash attention and deal with padding tokens in case the input contains any of them. """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + def forward( self, hidden_states: torch.Tensor, @@ -491,6 +500,8 @@ class LlamaFlashAttention2(LlamaAttention): past_key_value = (key_states, value_states) if use_cache else None + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) @@ -555,6 +566,12 @@ class LlamaFlashAttention2(LlamaAttention): softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] @@ -575,13 +592,13 @@ class LlamaFlashAttention2(LlamaAttention): max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=self.is_causal, + causal=causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal ) return attn_output diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 97fdf9ed87..dab8d4dae1 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -41,6 +41,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -289,6 +290,15 @@ class MBartFlashAttention2(MBartAttention): flash attention and deal with padding tokens in case the input contains any of them. """ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -413,6 +423,12 @@ class MBartFlashAttention2(MBartAttention): softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] @@ -433,13 +449,13 @@ class MBartFlashAttention2(MBartAttention): max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=self.is_causal, + causal=causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal ) return attn_output diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index d1c67880b6..0b23303d5e 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -37,6 +37,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -312,6 +313,15 @@ class MistralFlashAttention2(MistralAttention): flash attention and deal with padding tokens in case the input contains any of them. 
""" + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + def forward( self, hidden_states: torch.Tensor, @@ -470,6 +480,12 @@ class MistralFlashAttention2(MistralAttention): use_sliding_windows (`bool`, *optional*): Whether to activate sliding window attention. """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] @@ -491,7 +507,7 @@ class MistralFlashAttention2(MistralAttention): max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=self.is_causal, + causal=causal, ) else: attn_output_unpad = flash_attn_varlen_func( @@ -504,7 +520,7 @@ class MistralFlashAttention2(MistralAttention): max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=self.is_causal, + causal=causal, window_size=(self.config.sliding_window, self.config.sliding_window), ) @@ -517,7 +533,7 @@ class MistralFlashAttention2(MistralAttention): value_states, dropout, softmax_scale=softmax_scale, - causal=self.is_causal, + causal=causal, ) else: attn_output = flash_attn_func( @@ -526,7 +542,7 @@ class MistralFlashAttention2(MistralAttention): value_states, dropout, softmax_scale=softmax_scale, - causal=self.is_causal, + causal=causal, window_size=(self.config.sliding_window, self.config.sliding_window), ) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 2192f327bc..7d6b71dc09 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -35,6 +35,7 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -288,6 +289,15 @@ class OptFlashAttention2(OPTAttention): attention and deal with padding tokens in case the input contains any of them. """ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + def forward( self, hidden_states: torch.Tensor, @@ -404,6 +414,12 @@ class OptFlashAttention2(OPTAttention): softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] @@ -424,13 +440,13 @@ class OptFlashAttention2(OPTAttention): max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=self.is_causal, + causal=causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal ) return attn_output diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index e88fe3a6aa..aeb1a0284f 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -41,6 +41,7 @@ from ...utils import ( add_start_docstrings, 
add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -478,6 +479,15 @@ class WhisperFlashAttention2(WhisperAttention): flash attention and deal with padding tokens in case the input contains any of them. """ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -602,6 +612,12 @@ class WhisperFlashAttention2(WhisperAttention): softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] @@ -622,13 +638,13 @@ class WhisperFlashAttention2(WhisperAttention): max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=self.is_causal, + causal=causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal ) return attn_output diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index e7911e5d55..719f78af2a 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -118,6 +118,7 @@ from .import_utils import ( is_faiss_available, is_flash_attn_2_available, is_flash_attn_available, + is_flash_attn_greater_or_equal_2_10, is_flax_available, is_fsdp_available, is_ftfy_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index beb6c47795..2da0dbc891 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -71,9 +71,6 @@ TORCH_FX_REQUIRED_VERSION = version.parse("1.10") _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True) _apex_available = _is_package_available("apex") _bitsandbytes_available = _is_package_available("bitsandbytes") -_flash_attn_2_available = _is_package_available("flash_attn") and version.parse( - importlib.metadata.version("flash_attn") -) >= version.parse("2.1.0") # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. 
_bs4_available = importlib.util.find_spec("bs4") is not None _coloredlogs_available = _is_package_available("coloredlogs") @@ -608,10 +605,29 @@ def is_flash_attn_2_available(): if not is_torch_available(): return False + if not _is_package_available("flash_attn"): + return False + # Let's add an extra check to see if cuda is available import torch - return _flash_attn_2_available and torch.cuda.is_available() + if not torch.cuda.is_available(): + return False + + if torch.version.cuda: + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0") + elif torch.version.hip: + # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4") + else: + return False + + +def is_flash_attn_greater_or_equal_2_10(): + if not _is_package_available("flash_attn"): + return False + + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0") def is_flash_attn_available(): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index b76a8025e6..334e860452 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3087,7 +3087,7 @@ class ModelTesterMixin: dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False ) - self.assertTrue(torch.equal(out, out_fa)) + self.assertTrue(torch.allclose(out, out_fa)) @require_flash_attn @require_torch_gpu @@ -3130,7 +3130,7 @@ class ModelTesterMixin: dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False ) - self.assertTrue(torch.equal(out, out_fa)) + self.assertTrue(torch.allclose(out, out_fa)) @require_flash_attn @require_torch_gpu