[FA-2] Final fix for FA2 dtype (#26846)

* final fix for FA2 dtype

* try

* oops

* Update src/transformers/models/falcon/modeling_falcon.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* apply fix everywhere

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2023-10-18 19:48:55 +02:00
committed by GitHub
parent 732d2a8aac
commit 5a73316bed
4 changed files with 69 additions and 19 deletions

View File

@@ -64,6 +64,7 @@ from transformers.testing_utils import (
is_pt_flax_cross_test,
is_pt_tf_cross_test,
require_accelerate,
require_bitsandbytes,
require_flash_attn,
require_safetensors,
require_torch,
@@ -2959,6 +2960,45 @@ class ModelTesterMixin:
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=30, do_sample=False
)
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@mark.flash_attn_test
@slow
def test_flash_attn_2_fp32_ln(self):
import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
use_flash_attention_2=True,
low_cpu_mem_usage=True,
load_in_4bit=True,
)
for _, param in model.named_parameters():
# upcast only layer norms
if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
param.data = param.data.to(torch.float32)
_ = model(input_ids=dummy_input)
# with attention mask
_ = model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
global_rng = random.Random()