[core ] Integrate Flash attention 2 in most used models (#25598)
* v1 * oops * working v1 * fixup * add some TODOs * fixup * padding support + try with module replacement * nit * alternative design * oops * add `use_cache` support for llama * v1 falcon * nit * a bit of refactor * nit * nits nits * add v1 padding support falcon (even though it seemed to work before) * nit * falcon works * fixup * v1 tests * nit * fix generation llama flash * update tests * fix tests + nits * fix copies * fix nit * test- padding mask * stype * add more mem efficient support * Update src/transformers/modeling_utils.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * fixup * nit * fixup * remove it from config when saving * fixup * revert docstring * add more checks * use values * oops * new version * fixup * add same trick for falcon * nit * add another test * change tests * fix issues with GC and also falcon * fixup * oops * Update src/transformers/models/falcon/modeling_falcon.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add init_rope * updates * fix copies * fixup * fixup * more clarification * fixup * right padding tests * add docs * add FA in docker image * more clarifications * add some figures * add todo * rectify comment * Change to FA2 * Update docs/source/en/perf_infer_gpu_one.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * split in two lines * change test name * add more tests * some clean up * remove `rearrange` deps * add more docs * revert changes on dockerfile * Revert "revert changes on dockerfile" This reverts commit 8d72a66b4b9b771abc3f15a9b9506b4246d62d8e. * revert changes on dockerfile * Apply suggestions from code review Co-authored-by: Lysandre Debut <hi@lysand.re> * address some comments * docs * use inheritance * Update src/transformers/testing_utils.py Co-authored-by: Lysandre Debut <hi@lysand.re> * fixup * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/modeling_utils.py * final comments * clean up * style * add cast + warning for PEFT models * fixup --------- Co-authored-by: Felix Marty <9808326+fxmarty@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
@@ -60,6 +60,7 @@ from .utils import (
|
||||
is_detectron2_available,
|
||||
is_essentia_available,
|
||||
is_faiss_available,
|
||||
is_flash_attn_available,
|
||||
is_flax_available,
|
||||
is_fsdp_available,
|
||||
is_ftfy_available,
|
||||
@@ -392,6 +393,16 @@ def require_torch(test_case):
|
||||
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
|
||||
|
||||
|
||||
def require_flash_attn(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires Flash Attention.
|
||||
|
||||
These tests are skipped when Flash Attention isn't installed.
|
||||
|
||||
"""
|
||||
return unittest.skipUnless(is_flash_attn_available(), "test requires Flash Attention")(test_case)
|
||||
|
||||
|
||||
def require_peft(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PEFT.
|
||||
|
||||
Reference in New Issue
Block a user