[core ] Integrate Flash attention 2 in most used models (#25598)

* v1

* oops

* working v1

* fixup

* add some TODOs

* fixup

* padding support + try with module replacement

* nit

* alternative design

* oops

* add `use_cache` support for llama

* v1 falcon

* nit

* a bit of refactor

* nit

* nits nits

* add v1 padding support falcon (even though it seemed to work before)

* nit

* falcon works

* fixup

* v1 tests

* nit

* fix generation llama flash

* update tests

* fix tests + nits

* fix copies

* fix nit

* test- padding mask

* stype

* add more mem efficient support

* Update src/transformers/modeling_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fixup

* nit

* fixup

* remove it from config when saving

* fixup

* revert docstring

* add more checks

* use values

* oops

* new version

* fixup

* add same trick for falcon

* nit

* add another test

* change tests

* fix issues with GC and also falcon

* fixup

* oops

* Update src/transformers/models/falcon/modeling_falcon.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* add init_rope

* updates

* fix copies

* fixup

* fixup

* more clarification

* fixup

* right padding tests

* add docs

* add FA in docker image

* more clarifications

* add some figures

* add todo

* rectify comment

* Change to FA2

* Update docs/source/en/perf_infer_gpu_one.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* split in two lines

* change test name

* add more tests

* some clean up

* remove `rearrange` deps

* add more docs

* revert changes on dockerfile

* Revert "revert changes on dockerfile"

This reverts commit 8d72a66b4b9b771abc3f15a9b9506b4246d62d8e.

* revert changes on dockerfile

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <hi@lysand.re>

* address some comments

* docs

* use inheritance

* Update src/transformers/testing_utils.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* fixup

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

* final comments

* clean up

* style

* add cast + warning for PEFT models

* fixup

---------

Co-authored-by: Felix Marty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
Younes Belkada
2023-09-22 17:42:10 +02:00
committed by GitHub
parent dcbfd93d7a
commit 368a58e61c
14 changed files with 934 additions and 14 deletions

View File

@@ -18,9 +18,10 @@
import unittest
from parameterized import parameterized
from pytest import mark
from transformers import LlamaConfig, is_torch_available, set_seed
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
from transformers.testing_utils import require_flash_attn, require_torch, require_torch_gpu, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -375,6 +376,41 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
def test_flash_attn_2_generate_padding_right(self):
"""
Overwritting the common test as the test is flaky on tiny models
"""
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
load_in_4bit=True,
device_map={"": 0},
)
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
texts = ["hi", "Hello this is a very long sentence"]
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_native = tokenizer.batch_decode(output_native)
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
)
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_fa_2 = tokenizer.batch_decode(output_fa_2)
self.assertListEqual(output_native, output_fa_2)
@require_torch
class LlamaIntegrationTest(unittest.TestCase):