From 0ded281557f88add8733aa536763fae61207b382 Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Tue, 7 Nov 2023 19:24:01 +0530 Subject: [PATCH] [`FA2`] Add flash attention for `GPT-Neo` (#26486) * added flash attention for gpt-neo * small change Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * readme updated * . * changes * removed padding_mask * Update src/transformers/models/gpt_neo/modeling_gpt_neo.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- docs/source/en/model_doc/gpt_neo.md | 40 +++ .../models/gpt_neo/modeling_gpt_neo.py | 253 ++++++++++++++++-- 2 files changed, 269 insertions(+), 24 deletions(-) diff --git a/docs/source/en/model_doc/gpt_neo.md b/docs/source/en/model_doc/gpt_neo.md index 6a598ef6a9..fb2385bc73 100644 --- a/docs/source/en/model_doc/gpt_neo.md +++ b/docs/source/en/model_doc/gpt_neo.md @@ -54,6 +54,46 @@ The `generate()` method can be used to generate text using GPT Neo model. >>> gen_text = tokenizer.batch_decode(gen_tokens)[0] ``` +## Combining GPT-Neo and Flash Attention 2 + +First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature. + +```bash +pip install -U flash-attn --no-build-isolation +``` + +Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16``) + +To load and run a model using Flash Attention 2, refer to the snippet below: + +```python +>>> import torch +>>> from transformers import AutoModelForCausalLM, AutoTokenizer +>>> device = "cuda" # the device to load the model onto + +>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, use_flash_attention_2=True) +>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B") + +>>> prompt = "def hello_world():" + +>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device) +>>> model.to(device) + +>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True) +>>> tokenizer.batch_decode(generated_ids)[0] +"def hello_world():\n >>> run_script("hello.py")\n >>> exit(0)\n<|endoftext|>" +``` + +### Expected speedups + +Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using `EleutherAI/gpt-neo-2.7B` checkpoint and the Flash Attention 2 version of the model. +Note that for GPT-Neo it is not possible to train / run on very long context as the max [position embeddings](https://huggingface.co/EleutherAI/gpt-neo-2.7B/blob/main/config.json#L58 ) is limited to 2048 - but this is applicable to all gpt-neo models and not specific to FA-2 + +
+ +
+ + ## Resources - [Text classification task guide](../tasks/sequence_classification) diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 90ca265a82..aa1f129515 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -19,11 +19,13 @@ import os from typing import Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, @@ -34,10 +36,28 @@ from ...modeling_outputs import ( TokenClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_torch_fx_available, + logging, +) from .configuration_gpt_neo import GPTNeoConfig +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "GPTNeoConfig" @@ -50,6 +70,19 @@ GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST = [ _CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neo-1.3B" +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path): """Load tf checkpoints in a pytorch model""" try: @@ -133,6 +166,7 @@ def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path): class GPTNeoSelfAttention(nn.Module): def __init__(self, config, attention_type): super().__init__() + self.config = config max_positions = config.max_position_embeddings bias = torch.tril(torch.ones((max_positions, max_positions), dtype=bool)).view( @@ -150,6 +184,7 @@ class GPTNeoSelfAttention(nn.Module): self.attn_dropout = nn.Dropout(float(config.attention_dropout)) self.resid_dropout = nn.Dropout(float(config.resid_dropout)) + self.is_causal = True self.embed_dim = config.hidden_size self.num_heads = config.num_heads @@ -253,6 +288,185 @@ class GPTNeoSelfAttention(nn.Module): return outputs # a, present, (attentions) +class GPTNeoFlashAttention2(GPTNeoSelfAttention): + """ + GPTNeo flash attention module. This module inherits from `GPTNeoSelfAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states, + attention_mask=None, + layer_past=None, + head_mask=None, + use_cache=False, + output_attentions=False, + ): + bsz, _, _ = hidden_states.size() + + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_dropout = self.config.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout, softmax_scale=1.0 + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.out_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=self.is_causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + class GPTNeoAttention(nn.Module): def __init__(self, config, layer_id=0): super().__init__() @@ -261,7 +475,11 @@ class GPTNeoAttention(nn.Module): self.attention_type = self.attention_layers[layer_id] if self.attention_type in ["global", "local"]: - self.attention = GPTNeoSelfAttention(config, self.attention_type) + self.attention = ( + GPTNeoSelfAttention(config, self.attention_type) + if not getattr(config, "_flash_attn_2_enabled", False) + else GPTNeoFlashAttention2(config, self.attention_type) + ) else: raise NotImplementedError( "Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: " @@ -364,6 +582,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["GPTNeoBlock"] _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -524,10 +743,8 @@ class GPTNeoModel(GPTNeoPreTrainedModel): self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] else: raise ValueError("You have to specify either input_ids or inputs_embeds") @@ -546,26 +763,6 @@ class GPTNeoModel(GPTNeoPreTrainedModel): position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) position_ids = position_ids.unsqueeze(0) - # Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x num_heads x N x N @@ -577,6 +774,14 @@ class GPTNeoModel(GPTNeoPreTrainedModel): position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds + # Attention mask. + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, past_length) + if token_type_ids is not None: token_type_embeds = self.wte(token_type_ids) hidden_states = hidden_states + token_type_embeds