[core ] Integrate Flash attention 2 in most used models (#25598)

* v1

* oops

* working v1

* fixup

* add some TODOs

* fixup

* padding support + try with module replacement

* nit

* alternative design

* oops

* add `use_cache` support for llama

* v1 falcon

* nit

* a bit of refactor

* nit

* nits nits

* add v1 padding support falcon (even though it seemed to work before)

* nit

* falcon works

* fixup

* v1 tests

* nit

* fix generation llama flash

* update tests

* fix tests + nits

* fix copies

* fix nit

* test- padding mask

* stype

* add more mem efficient support

* Update src/transformers/modeling_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fixup

* nit

* fixup

* remove it from config when saving

* fixup

* revert docstring

* add more checks

* use values

* oops

* new version

* fixup

* add same trick for falcon

* nit

* add another test

* change tests

* fix issues with GC and also falcon

* fixup

* oops

* Update src/transformers/models/falcon/modeling_falcon.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* add init_rope

* updates

* fix copies

* fixup

* fixup

* more clarification

* fixup

* right padding tests

* add docs

* add FA in docker image

* more clarifications

* add some figures

* add todo

* rectify comment

* Change to FA2

* Update docs/source/en/perf_infer_gpu_one.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* split in two lines

* change test name

* add more tests

* some clean up

* remove `rearrange` deps

* add more docs

* revert changes on dockerfile

* Revert "revert changes on dockerfile"

This reverts commit 8d72a66b4b9b771abc3f15a9b9506b4246d62d8e.

* revert changes on dockerfile

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <hi@lysand.re>

* address some comments

* docs

* use inheritance

* Update src/transformers/testing_utils.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* fixup

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

* final comments

* clean up

* style

* add cast + warning for PEFT models

* fixup

---------

Co-authored-by: Felix Marty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
Younes Belkada
2023-09-22 17:42:10 +02:00
committed by GitHub
parent dcbfd93d7a
commit 368a58e61c
14 changed files with 934 additions and 14 deletions

View File

@@ -64,6 +64,7 @@ from transformers.testing_utils import (
is_pt_flax_cross_test,
is_pt_tf_cross_test,
require_accelerate,
require_flash_attn,
require_safetensors,
require_torch,
require_torch_gpu,
@@ -2722,6 +2723,191 @@ class ModelTesterMixin:
num_params < 1000000
), f"{model_class} is too big for the common tests ({num_params})! It should have 1M max."
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_conversion(self):
import torch
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True
).to(torch_device)
for _, module in model.named_modules():
if "FlashAttention" in module.__class__.__name__:
return
self.assertTrue(False, "FlashAttention2 modules not found in model")
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference(self):
import torch
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
)
model.to(torch_device)
dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[0, 1, 1, 1, 1]]).to(torch_device)
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits_fa = output_fa.hidden_states[-1]
output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits = output.hidden_states[-1]
self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_inference_padding_right(self):
import torch
for model_class in self.all_model_classes:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True
)
model_fa.to(torch_device)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False
)
model.to(torch_device)
dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1, 0]]).to(torch_device)
logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits_fa = output_fa.hidden_states[-1]
output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits = output.hidden_states[-1]
self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_left_padding(self):
import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
).to(torch_device)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
self.assertTrue(torch.equal(out, out_fa))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
import torch
for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn_2:
return
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True
).to(torch_device)
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
out = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
model = model_class.from_pretrained(
tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True
).to(torch_device)
out_fa = model.generate(
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
self.assertTrue(torch.equal(out, out_fa))
global_rng = random.Random()