F.scaled_dot_product_attention support (#26572)

* add sdpa

* wip

* cleaning

* add ref

* yet more cleaning

* and more :)

* wip llama

* working llama

* add output_attentions=True support

* bigcode sdpa support

* fixes

* gpt-bigcode support, require torch>=2.1.1

* add falcon support

* fix conflicts falcon

* style

* fix attention_mask definition

* remove output_attentions from attnmaskconverter

* support whisper without removing any Copied from statement

* fix mbart default to eager renaming

* fix typo in falcon

* fix is_causal in SDPA

* check is_flash_attn_2_available in the models init as well in case the model is not initialized through from_pretrained

* add warnings when falling back on the manual implementation

* precise doc

* wip replace _flash_attn_enabled by config.attn_implementation

* fix typo

* add tests

* style

* add a copy.deepcopy on the config in from_pretrained, as we do not want to modify it inplace

* obey to config.attn_implementation if a config is passed in from_pretrained

* fix is_torch_sdpa_available when torch is not installed

* remove dead code

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bart/modeling_bart.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove duplicate pretraining_tp code

* add dropout in llama

* precise comment on attn_mask

* add fmt: off for _unmask_unattended docstring

* precise num_masks comment

* nuke pretraining_tp in LlamaSDPAAttention following Arthur's suggestion

* cleanup modeling_utils

* backward compatibility

* fix style as requested

* style

* improve documentation

* test pass

* style

* add _unmask_unattended tests

* skip meaningless tests for idefics

* hard_check SDPA requirements when specifically requested

* standardize the use if XXX_ATTENTION_CLASSES

* fix SDPA bug with mem-efficient backend on CUDA when using fp32

* fix test

* rely on SDPA is_causal parameter to handle the causal mask in some cases

* fix FALCON_ATTENTION_CLASSES

* remove _flash_attn_2_enabled occurences

* fix test

* add OPT to the list of supported flash models

* improve test

* properly test on different SDPA backends, on different dtypes & properly handle separately the pad tokens in the test

* remove remaining _flash_attn_2_enabled occurence

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/perf_infer_gpu_one.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove use_attn_implementation

* fix docstring & slight bug

* make attn_implementation internal (_attn_implementation)

* typos

* fix tests

* deprecate use_flash_attention_2=True

* fix test

* add back llama that was removed by mistake

* fix tests

* remove _flash_attn_2_enabled occurences bis

* add check & test that passed attn_implementation is valid

* fix falcon torchscript export

* fix device of mask in tests

* add tip about torch.jit.trace and move bt doc below sdpa

* fix parameterized.expand order

* move tests from test_modeling_attn_mask_utils to test_modeling_utils as a relevant test class is already there

* update sdpaattention class with the new cache

* Update src/transformers/configuration_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bark/modeling_bark.py

* address review comments

* WIP torch.jit.trace fix. left: test both eager & sdpa

* add test for torch.jit.trace for both eager/sdpa

* fix falcon with torch==2.0 that needs to use sdpa

* fix doc

* hopefully last fix

* fix key_value_length that has no default now in mask converter

* is it flacky?

* fix speculative decoding bug

* tests do pass

* fix following #27907

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
fxmarty
2023-12-08 21:38:14 +01:00
committed by GitHub
parent ce0bbd5101
commit 80377eb018
54 changed files with 2227 additions and 454 deletions

View File

@@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import doctest
import logging
import os
import unittest
from glob import glob
from pathlib import Path
from typing import List, Union
@@ -27,6 +27,63 @@ from transformers.testing_utils import require_tf, require_torch, slow
logger = logging.getLogger()
@require_torch
class TestDocLists(unittest.TestCase):
def test_flash_support_list(self):
with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
doctext = f.read()
doctext = doctext.split("FlashAttention-2 is currently supported for the following architectures:")[1]
doctext = doctext.split("You can request to add FlashAttention-2 support")[0]
patterns = glob("./src/transformers/models/**/modeling_*.py")
patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
archs_supporting_fa2 = []
for filename in patterns:
with open(filename, "r") as f:
text = f.read()
if "_supports_flash_attn_2 = True" in text:
model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
archs_supporting_fa2.append(model_name)
for arch in archs_supporting_fa2:
if arch not in doctext:
raise ValueError(
f"{arch} should be in listed in the flash attention documentation but is not. Please update the documentation."
)
def test_sdpa_support_list(self):
with open("./docs/source/en/perf_infer_gpu_one.md", "r") as f:
doctext = f.read()
doctext = doctext.split(
"For now, Transformers supports inference and training through SDPA for the following architectures:"
)[1]
doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0]
patterns = glob("./src/transformers/models/**/modeling_*.py")
patterns_tf = glob("./src/transformers/models/**/modeling_tf_*.py")
patterns_flax = glob("./src/transformers/models/**/modeling_flax_*.py")
patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
archs_supporting_sdpa = []
for filename in patterns:
with open(filename, "r") as f:
text = f.read()
if "_supports_sdpa = True" in text:
model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
archs_supporting_sdpa.append(model_name)
for arch in archs_supporting_sdpa:
if arch not in doctext:
raise ValueError(
f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation."
)
@unittest.skip("Temporarily disable the doc tests.")
@require_torch
@require_tf