[core ] Integrate Flash attention 2 in most used models (#25598)
* v1 * oops * working v1 * fixup * add some TODOs * fixup * padding support + try with module replacement * nit * alternative design * oops * add `use_cache` support for llama * v1 falcon * nit * a bit of refactor * nit * nits nits * add v1 padding support falcon (even though it seemed to work before) * nit * falcon works * fixup * v1 tests * nit * fix generation llama flash * update tests * fix tests + nits * fix copies * fix nit * test- padding mask * stype * add more mem efficient support * Update src/transformers/modeling_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * fixup * nit * fixup * remove it from config when saving * fixup * revert docstring * add more checks * use values * oops * new version * fixup * add same trick for falcon * nit * add another test * change tests * fix issues with GC and also falcon * fixup * oops * Update src/transformers/models/falcon/modeling_falcon.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add init_rope * updates * fix copies * fixup * fixup * more clarification * fixup * right padding tests * add docs * add FA in docker image * more clarifications * add some figures * add todo * rectify comment * Change to FA2 * Update docs/source/en/perf_infer_gpu_one.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * split in two lines * change test name * add more tests * some clean up * remove `rearrange` deps * add more docs * revert changes on dockerfile * Revert "revert changes on dockerfile" This reverts commit 8d72a66b4b9b771abc3f15a9b9506b4246d62d8e. * revert changes on dockerfile * Apply suggestions from code review Co-authored-by: Lysandre Debut <hi@lysand.re> * address some comments * docs * use inheritance * Update src/transformers/testing_utils.py Co-authored-by: Lysandre Debut <hi@lysand.re> * fixup * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/modeling_utils.py * final comments * clean up * style * add cast + warning for PEFT models * fixup --------- Co-authored-by: Felix Marty <9808326+fxmarty@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
@@ -18,9 +18,10 @@
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
|
||||
from transformers import LlamaConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||
from transformers.testing_utils import require_flash_attn, require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -375,6 +376,41 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
# The output should be different for long inputs
|
||||
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_generate_padding_right(self):
|
||||
"""
|
||||
Overwritting the common test as the test is flaky on tiny models
|
||||
"""
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
load_in_4bit=True,
|
||||
device_map={"": 0},
|
||||
)
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
|
||||
texts = ["hi", "Hello this is a very long sentence"]
|
||||
|
||||
tokenizer.padding_side = "right"
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
|
||||
|
||||
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_native = tokenizer.batch_decode(output_native)
|
||||
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, use_flash_attention_2=True
|
||||
)
|
||||
|
||||
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_fa_2 = tokenizer.batch_decode(output_fa_2)
|
||||
|
||||
self.assertListEqual(output_native, output_fa_2)
|
||||
|
||||
|
||||
@require_torch
|
||||
class LlamaIntegrationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user