[FA-2] Final fix for FA2 dtype (#26846)
* final fix for FA2 dtype * try * oops * Update src/transformers/models/falcon/modeling_falcon.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * apply fix everywhere --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -613,15 +613,18 @@ class FalconFlashAttention2(FalconAttention):
|
|||||||
# cast them back in float16 just to be sure everything works as expected.
|
# cast them back in float16 just to be sure everything works as expected.
|
||||||
input_dtype = query_layer.dtype
|
input_dtype = query_layer.dtype
|
||||||
if input_dtype == torch.float32:
|
if input_dtype == torch.float32:
|
||||||
|
# Handle the case where the model is quantized
|
||||||
|
target_dtype = getattr(self.config, "_pre_quantization_dtype", self.query_key_value.weight.dtype)
|
||||||
|
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"The input hidden states seems to be silently casted in float32, this might be related to"
|
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||||
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||||
" float16."
|
f" {target_dtype}."
|
||||||
)
|
)
|
||||||
|
|
||||||
query_layer = query_layer.to(torch.float16)
|
query_layer = query_layer.to(target_dtype)
|
||||||
key_layer = key_layer.to(torch.float16)
|
key_layer = key_layer.to(target_dtype)
|
||||||
value_layer = value_layer.to(torch.float16)
|
value_layer = value_layer.to(target_dtype)
|
||||||
|
|
||||||
attn_output = self._flash_attention_forward(
|
attn_output = self._flash_attention_forward(
|
||||||
query_layer, key_layer, value_layer, padding_mask, query_length, dropout=attn_dropout
|
query_layer, key_layer, value_layer, padding_mask, query_length, dropout=attn_dropout
|
||||||
|
|||||||
@@ -469,20 +469,24 @@ class LlamaFlashAttention2(LlamaAttention):
|
|||||||
|
|
||||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||||
# cast them back in float16 just to be sure everything works as expected.
|
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||||
# in fp32. (LlamaRMSNorm handles it correctly)
|
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||||
|
|
||||||
input_dtype = query_states.dtype
|
input_dtype = query_states.dtype
|
||||||
if input_dtype == torch.float32:
|
if input_dtype == torch.float32:
|
||||||
|
# Handle the case where the model is quantized
|
||||||
|
target_dtype = getattr(self.config, "_pre_quantization_dtype", self.q_proj.weight.dtype)
|
||||||
|
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"The input hidden states seems to be silently casted in float32, this might be related to"
|
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||||
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||||
" float16."
|
f" {target_dtype}."
|
||||||
)
|
)
|
||||||
|
|
||||||
query_states = query_states.to(torch.float16)
|
query_states = query_states.to(target_dtype)
|
||||||
key_states = key_states.to(torch.float16)
|
key_states = key_states.to(target_dtype)
|
||||||
value_states = value_states.to(torch.float16)
|
value_states = value_states.to(target_dtype)
|
||||||
|
|
||||||
attn_output = self._flash_attention_forward(
|
attn_output = self._flash_attention_forward(
|
||||||
query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
|
query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
|
||||||
|
|||||||
@@ -408,15 +408,18 @@ class MistralFlashAttention2(MistralAttention):
|
|||||||
# cast them back in float16 just to be sure everything works as expected.
|
# cast them back in float16 just to be sure everything works as expected.
|
||||||
input_dtype = query_states.dtype
|
input_dtype = query_states.dtype
|
||||||
if input_dtype == torch.float32:
|
if input_dtype == torch.float32:
|
||||||
|
# Handle the case where the model is quantized
|
||||||
|
target_dtype = getattr(self.config, "_pre_quantization_dtype", self.q_proj.weight.dtype)
|
||||||
|
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"The input hidden states seems to be silently casted in float32, this might be related to"
|
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||||
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||||
" float16."
|
f" {target_dtype}."
|
||||||
)
|
)
|
||||||
|
|
||||||
query_states = query_states.to(torch.float16)
|
query_states = query_states.to(target_dtype)
|
||||||
key_states = key_states.to(torch.float16)
|
key_states = key_states.to(target_dtype)
|
||||||
value_states = value_states.to(torch.float16)
|
value_states = value_states.to(target_dtype)
|
||||||
|
|
||||||
# Reashape to the expected shape for Flash Attention
|
# Reashape to the expected shape for Flash Attention
|
||||||
query_states = query_states.transpose(1, 2)
|
query_states = query_states.transpose(1, 2)
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ from transformers.testing_utils import (
|
|||||||
is_pt_flax_cross_test,
|
is_pt_flax_cross_test,
|
||||||
is_pt_tf_cross_test,
|
is_pt_tf_cross_test,
|
||||||
require_accelerate,
|
require_accelerate,
|
||||||
|
require_bitsandbytes,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_safetensors,
|
require_safetensors,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -2959,6 +2960,45 @@ class ModelTesterMixin:
|
|||||||
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False
|
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_flash_attn
|
||||||
|
@require_torch_gpu
|
||||||
|
@require_bitsandbytes
|
||||||
|
@mark.flash_attn_test
|
||||||
|
@slow
|
||||||
|
def test_flash_attn_2_fp32_ln(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
if not model_class._supports_flash_attn_2:
|
||||||
|
return
|
||||||
|
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
||||||
|
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
|
||||||
|
|
||||||
|
model = model_class.from_pretrained(
|
||||||
|
tmpdirname,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
use_flash_attention_2=True,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
load_in_4bit=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, param in model.named_parameters():
|
||||||
|
# upcast only layer norms
|
||||||
|
if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
|
||||||
|
param.data = param.data.to(torch.float32)
|
||||||
|
|
||||||
|
_ = model(input_ids=dummy_input)
|
||||||
|
|
||||||
|
# with attention mask
|
||||||
|
_ = model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
|
||||||
|
|
||||||
|
|
||||||
global_rng = random.Random()
|
global_rng = random.Random()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user