From 9270ab082740a55344a851049a0b69673b6cbdc5 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Wed, 6 Dec 2023 17:22:32 +0100
Subject: [PATCH] [`Flash Attention 2`] Add flash attention 2 for GPT-Neo-X
(#26463)
* add flash-attn-2 support for GPT-neo-x
* fixup
* add comment
* revert
* fixes
* update docs
* comment
* again
* fix copies
* add plot + fix copies
* Update docs/source/en/model_doc/gpt_neox.md
---
docs/source/en/model_doc/gpt_neox.md | 34 +++
.../models/gpt_neox/modeling_gpt_neox.py | 279 +++++++++++++++++-
2 files changed, 298 insertions(+), 15 deletions(-)
diff --git a/docs/source/en/model_doc/gpt_neox.md b/docs/source/en/model_doc/gpt_neox.md
index 300001ad5b..1885d44450 100644
--- a/docs/source/en/model_doc/gpt_neox.md
+++ b/docs/source/en/model_doc/gpt_neox.md
@@ -61,6 +61,40 @@ The `generate()` method can be used to generate text using GPT Neo model.
>>> gen_text = tokenizer.batch_decode(gen_tokens)[0]
```
+## Using Flash Attention 2
+
+Flash Attention 2 is an faster, optimized version of the model.
+
+### Installation
+
+First, check whether your hardware is compatible with Flash Attention 2. The latest list of compatible hardware can be found in the [official documentation](https://github.com/Dao-AILab/flash-attention#installation-and-features). If your hardware is not compatible with Flash Attention 2, you can still benefit from attention kernel optimisations through Better Transformer support covered [above](https://huggingface.co/docs/transformers/main/en/model_doc/bark#using-better-transformer).
+
+Next, [install](https://github.com/Dao-AILab/flash-attention#installation-and-features) the latest version of Flash Attention 2:
+
+```bash
+pip install -U flash-attn --no-build-isolation
+```
+
+### Usage
+
+To load a model using Flash Attention 2, we can pass the `use_flash_attention_2` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
+
+```python
+>>> from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
+
+model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
+...
+```
+
+
+### Expected speedups
+
+Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using `stockmark/gpt-neox-japanese-1.4b` checkpoint and the Flash Attention 2 version of the model using a sequence length of 2048.
+
+
+

+
+
## Resources
- [Causal language modeling task guide](../tasks/language_modeling)
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 4d783ce5d5..30feda146e 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -20,6 +20,7 @@ import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn import functional as F
from ...activations import ACT2FN
from ...file_utils import (
@@ -36,10 +37,15 @@ from ...modeling_outputs import (
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...utils import logging
+from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
from .configuration_gpt_neox import GPTNeoXConfig
+if is_flash_attn_2_available():
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM"
@@ -52,6 +58,19 @@ GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
class GPTNeoXPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -63,6 +82,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["GPTNeoXLayer"]
_skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
def _init_weights(self, module):
"""Initialize the weights"""
@@ -100,6 +120,7 @@ class GPTNeoXAttention(nn.Module):
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.attention_dropout = nn.Dropout(config.attention_dropout)
+ self.is_causal = True
def _init_bias(self, max_positions, device=None):
self.register_buffer(
@@ -146,6 +167,7 @@ class GPTNeoXAttention(nn.Module):
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
+ padding_mask: Optional[torch.Tensor] = None,
):
has_layer_past = layer_past is not None
@@ -277,6 +299,226 @@ class GPTNeoXAttention(nn.Module):
return attn_output, attn_weights
+class GPTNeoXFlashAttention2(GPTNeoXAttention):
+ """
+ GPTNeoX flash attention module. This module inherits from `GPTNeoXAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: torch.FloatTensor,
+ position_ids: torch.LongTensor,
+ head_mask: Optional[torch.FloatTensor] = None,
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ):
+ has_layer_past = layer_past is not None
+
+ # Compute QKV
+ # Attention heads [batch, seq_len, hidden_size]
+ # --> [batch, seq_len, (np * 3 * head_size)]
+ qkv = self.query_key_value(hidden_states)
+
+ # [batch, seq_len, (num_heads * 3 * head_size)]
+ # --> [batch, seq_len, num_heads, 3 * head_size]
+ new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
+ qkv = qkv.view(*new_qkv_shape)
+
+ # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
+ query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
+ key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
+ value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
+
+ query_length = query.shape[-2]
+
+ # Compute rotary embeddings on rotary_ndims
+ query_rot = query[..., : self.rotary_ndims]
+ query_pass = query[..., self.rotary_ndims :]
+ key_rot = key[..., : self.rotary_ndims]
+ key_pass = key[..., self.rotary_ndims :]
+
+ # Compute token offset for rotary embeddings (when decoding)
+ seq_len = key.shape[-2]
+ if has_layer_past:
+ seq_len += layer_past[0].shape[-2]
+ cos, sin = self.rotary_emb(value, seq_len=seq_len)
+ query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
+ query = torch.cat((query, query_pass), dim=-1)
+ key = torch.cat((key, key_pass), dim=-1)
+
+ # Cache QKV values
+ if has_layer_past:
+ past_key = layer_past[0]
+ past_value = layer_past[1]
+ key = torch.cat((past_key, key), dim=-2)
+ value = torch.cat((past_value, value), dim=-2)
+ present = (key, value) if use_cache else None
+
+ # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision
+ target_dtype = value.dtype
+ if query.dtype != target_dtype:
+ query = query.to(target_dtype)
+ if key.dtype != target_dtype:
+ key = key.to(target_dtype)
+
+ # Permute to get the expected shape for Flash Attention
+ query = query.permute(0, 2, 1, 3)
+ key = key.permute(0, 2, 1, 3)
+ value = value.permute(0, 2, 1, 3)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 / bfloat16 just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ input_dtype = query.dtype
+ if input_dtype == torch.float32:
+ # Handle the case where the model is quantized
+ if hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query = query.to(target_dtype)
+ key = key.to(target_dtype)
+ value = value.to(target_dtype)
+
+ attention_dropout = self.config.attention_dropout if self.training else 0.0
+
+ # Compute attention
+ attn_weights = self._flash_attention_forward(
+ query, key, value, attention_mask, query_length, dropout=attention_dropout, softmax_scale=self.norm_factor
+ )
+
+ # Reshape outputs
+ attn_output = attn_weights.reshape(
+ attn_weights.shape[0], attn_weights.shape[1], self.num_attention_heads * self.head_size
+ )
+ attn_output = self.dense(attn_output)
+
+ outputs = (attn_output, present)
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+ def _flash_attention_forward(
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+ ):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`int`, *optional*):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ """
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ attn_output = flash_attn_func(
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+ )
+
+ return attn_output
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->num_attention_heads
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads, head_dim), indices_k
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
def attention_mask_func(attention_scores, ltor_mask):
attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
return attention_scores
@@ -424,7 +666,11 @@ class GPTNeoXLayer(nn.Module):
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_dropout = nn.Dropout(config.hidden_dropout)
self.post_mlp_dropout = nn.Dropout(config.hidden_dropout)
- self.attention = GPTNeoXAttention(config)
+ self.attention = (
+ GPTNeoXAttention(config)
+ if not getattr(config, "_flash_attn_2_enabled", False)
+ else GPTNeoXFlashAttention2(config)
+ )
self.mlp = GPTNeoXMLP(config)
def forward(
@@ -615,20 +861,23 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
if attention_mask is not None:
assert batch_size > 0, "batch_size has to be defined and > 0"
attention_mask = attention_mask.view(batch_size, -1)
- # We create a 3D attention mask from a 2D tensor mask.
- # Sizes are [batch_size, 1, 1, to_seq_length]
- # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
- # this attention mask is more simple than the triangular masking of causal attention
- # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
- attention_mask = attention_mask[:, None, None, :]
+ if getattr(self.config, "_flash_attn_2_enabled", False):
+ attention_mask = attention_mask if 0 in attention_mask else None
+ else:
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
- # masked positions, this operation will create a tensor which is 0.0 for
- # positions we want to attend and the dtype's smallest value for masked positions.
- # Since we are adding it to the raw scores before the softmax, this is
- # effectively the same as removing these entirely.
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
- attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head