[core ] Integrate Flash attention 2 in most used models (#25598)
* v1 * oops * working v1 * fixup * add some TODOs * fixup * padding support + try with module replacement * nit * alternative design * oops * add `use_cache` support for llama * v1 falcon * nit * a bit of refactor * nit * nits nits * add v1 padding support falcon (even though it seemed to work before) * nit * falcon works * fixup * v1 tests * nit * fix generation llama flash * update tests * fix tests + nits * fix copies * fix nit * test- padding mask * stype * add more mem efficient support * Update src/transformers/modeling_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * fixup * nit * fixup * remove it from config when saving * fixup * revert docstring * add more checks * use values * oops * new version * fixup * add same trick for falcon * nit * add another test * change tests * fix issues with GC and also falcon * fixup * oops * Update src/transformers/models/falcon/modeling_falcon.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add init_rope * updates * fix copies * fixup * fixup * more clarification * fixup * right padding tests * add docs * add FA in docker image * more clarifications * add some figures * add todo * rectify comment * Change to FA2 * Update docs/source/en/perf_infer_gpu_one.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * split in two lines * change test name * add more tests * some clean up * remove `rearrange` deps * add more docs * revert changes on dockerfile * Revert "revert changes on dockerfile" This reverts commit 8d72a66b4b9b771abc3f15a9b9506b4246d62d8e. * revert changes on dockerfile * Apply suggestions from code review Co-authored-by: Lysandre Debut <hi@lysand.re> * address some comments * docs * use inheritance * Update src/transformers/testing_utils.py Co-authored-by: Lysandre Debut <hi@lysand.re> * fixup * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/modeling_utils.py * final comments * clean up * style * add cast + warning for PEFT models * fixup --------- Co-authored-by: Felix Marty <9808326+fxmarty@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
@@ -64,6 +64,7 @@ from transformers.testing_utils import (
|
||||
is_pt_flax_cross_test,
|
||||
is_pt_tf_cross_test,
|
||||
require_accelerate,
|
||||
require_flash_attn,
|
||||
require_safetensors,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
@@ -2722,6 +2723,191 @@ class ModelTesterMixin:
|
||||
num_params < 1000000
|
||||
), f"{model_class} is too big for the common tests ({num_params})! It should have 1M max."
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_conversion(self):
|
||||
import torch
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True
|
||||
).to(torch_device)
|
||||
|
||||
for _, module in model.named_modules():
|
||||
if "FlashAttention" in module.__class__.__name__:
|
||||
return
|
||||
|
||||
self.assertTrue(False, "FlashAttention2 modules not found in model")
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_inference(self):
|
||||
import torch
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
|
||||
)
|
||||
model.to(torch_device)
|
||||
|
||||
dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
|
||||
dummy_attention_mask = torch.LongTensor([[0, 1, 1, 1, 1]]).to(torch_device)
|
||||
|
||||
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
|
||||
logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
|
||||
|
||||
self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
|
||||
|
||||
output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
|
||||
logits_fa = output_fa.hidden_states[-1]
|
||||
|
||||
output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
|
||||
logits = output.hidden_states[-1]
|
||||
|
||||
self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_inference_padding_right(self):
|
||||
import torch
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
|
||||
)
|
||||
model.to(torch_device)
|
||||
|
||||
dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
|
||||
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1, 0]]).to(torch_device)
|
||||
|
||||
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
|
||||
logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
|
||||
|
||||
self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
|
||||
|
||||
output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
|
||||
logits_fa = output_fa.hidden_states[-1]
|
||||
|
||||
output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
|
||||
logits = output.hidden_states[-1]
|
||||
|
||||
self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2))
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_generate_left_padding(self):
|
||||
import torch
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
|
||||
).to(torch_device)
|
||||
|
||||
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
||||
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
|
||||
|
||||
out = model.generate(
|
||||
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
|
||||
)
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
|
||||
).to(torch_device)
|
||||
|
||||
out_fa = model.generate(
|
||||
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
|
||||
)
|
||||
|
||||
self.assertTrue(torch.equal(out, out_fa))
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_generate_padding_right(self):
|
||||
import torch
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
|
||||
).to(torch_device)
|
||||
|
||||
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
||||
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
|
||||
|
||||
out = model.generate(
|
||||
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
|
||||
)
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
|
||||
).to(torch_device)
|
||||
|
||||
out_fa = model.generate(
|
||||
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
|
||||
)
|
||||
|
||||
self.assertTrue(torch.equal(out, out_fa))
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user