[FA-2] Final fix for FA2 dtype (#26846)

* final fix for FA2 dtype

* try

* oops

* Update src/transformers/models/falcon/modeling_falcon.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* apply fix everywhere

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2023-10-18 19:48:55 +02:00
committed by GitHub
parent 732d2a8aac
commit 5a73316bed
4 changed files with 69 additions and 19 deletions

View File

@@ -613,15 +613,18 @@ class FalconFlashAttention2(FalconAttention):
# cast them back in float16 just to be sure everything works as expected. # cast them back in float16 just to be sure everything works as expected.
input_dtype = query_layer.dtype input_dtype = query_layer.dtype
if input_dtype == torch.float32: if input_dtype == torch.float32:
# Handle the case where the model is quantized
target_dtype = getattr(self.config, "_pre_quantization_dtype", self.query_key_value.weight.dtype)
logger.warning_once( logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to" f"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16." f" {target_dtype}."
) )
query_layer = query_layer.to(torch.float16) query_layer = query_layer.to(target_dtype)
key_layer = key_layer.to(torch.float16) key_layer = key_layer.to(target_dtype)
value_layer = value_layer.to(torch.float16) value_layer = value_layer.to(target_dtype)
attn_output = self._flash_attention_forward( attn_output = self._flash_attention_forward(
query_layer, key_layer, value_layer, padding_mask, query_length, dropout=attn_dropout query_layer, key_layer, value_layer, padding_mask, query_length, dropout=attn_dropout

View File

@@ -469,20 +469,24 @@ class LlamaFlashAttention2(LlamaAttention):
# In PEFT, usually we cast the layer norms in float32 for training stability reasons # In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need # therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected. # cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms # This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly) # in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype input_dtype = query_states.dtype
if input_dtype == torch.float32: if input_dtype == torch.float32:
# Handle the case where the model is quantized
target_dtype = getattr(self.config, "_pre_quantization_dtype", self.q_proj.weight.dtype)
logger.warning_once( logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to" f"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16." f" {target_dtype}."
) )
query_states = query_states.to(torch.float16) query_states = query_states.to(target_dtype)
key_states = key_states.to(torch.float16) key_states = key_states.to(target_dtype)
value_states = value_states.to(torch.float16) value_states = value_states.to(target_dtype)
attn_output = self._flash_attention_forward( attn_output = self._flash_attention_forward(
query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate

View File

@@ -408,15 +408,18 @@ class MistralFlashAttention2(MistralAttention):
# cast them back in float16 just to be sure everything works as expected. # cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype input_dtype = query_states.dtype
if input_dtype == torch.float32: if input_dtype == torch.float32:
# Handle the case where the model is quantized
target_dtype = getattr(self.config, "_pre_quantization_dtype", self.q_proj.weight.dtype)
logger.warning_once( logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to" f"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16." f" {target_dtype}."
) )
query_states = query_states.to(torch.float16) query_states = query_states.to(target_dtype)
key_states = key_states.to(torch.float16) key_states = key_states.to(target_dtype)
value_states = value_states.to(torch.float16) value_states = value_states.to(target_dtype)
# Reashape to the expected shape for Flash Attention # Reashape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2) query_states = query_states.transpose(1, 2)

View File

@@ -64,6 +64,7 @@ from transformers.testing_utils import (
is_pt_flax_cross_test, is_pt_flax_cross_test,
is_pt_tf_cross_test, is_pt_tf_cross_test,
require_accelerate, require_accelerate,
require_bitsandbytes,
require_flash_attn, require_flash_attn,
require_safetensors, require_safetensors,
require_torch, require_torch,
@@ -2959,6 +2960,45 @@ class ModelTesterMixin:
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False
) )
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@mark.flash_attn_test
@slow
def test_flash_attn_2_fp32_ln(self):
import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
use_flash_attention_2=True,
low_cpu_mem_usage=True,
load_in_4bit=True,
)
for _, param in model.named_parameters():
# upcast only layer norms
if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
param.data = param.data.to(torch.float32)
_ = model(input_ids=dummy_input)
# with attention mask
_ = model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
global_rng = random.Random() global_rng = random.Random()