F.scaled_dot_product_attention support (#26572)

* add sdpa

* wip

* cleaning

* add ref

* yet more cleaning

* and more :)

* wip llama

* working llama

* add output_attentions=True support

* bigcode sdpa support

* fixes

* gpt-bigcode support, require torch>=2.1.1

* add falcon support

* fix conflicts falcon

* style

* fix attention_mask definition

* remove output_attentions from attnmaskconverter

* support whisper without removing any Copied from statement

* fix mbart default to eager renaming

* fix typo in falcon

* fix is_causal in SDPA

* check is_flash_attn_2_available in the models init as well in case the model is not initialized through from_pretrained

* add warnings when falling back on the manual implementation

* precise doc

* wip replace _flash_attn_enabled by config.attn_implementation

* fix typo

* add tests

* style

* add a copy.deepcopy on the config in from_pretrained, as we do not want to modify it inplace

* obey to config.attn_implementation if a config is passed in from_pretrained

* fix is_torch_sdpa_available when torch is not installed

* remove dead code

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bart/modeling_bart.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove duplicate pretraining_tp code

* add dropout in llama

* precise comment on attn_mask

* add fmt: off for _unmask_unattended docstring

* precise num_masks comment

* nuke pretraining_tp in LlamaSDPAAttention following Arthur's suggestion

* cleanup modeling_utils

* backward compatibility

* fix style as requested

* style

* improve documentation

* test pass

* style

* add _unmask_unattended tests

* skip meaningless tests for idefics

* hard_check SDPA requirements when specifically requested

* standardize the use if XXX_ATTENTION_CLASSES

* fix SDPA bug with mem-efficient backend on CUDA when using fp32

* fix test

* rely on SDPA is_causal parameter to handle the causal mask in some cases

* fix FALCON_ATTENTION_CLASSES

* remove _flash_attn_2_enabled occurences

* fix test

* add OPT to the list of supported flash models

* improve test

* properly test on different SDPA backends, on different dtypes & properly handle separately the pad tokens in the test

* remove remaining _flash_attn_2_enabled occurence

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/perf_infer_gpu_one.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove use_attn_implementation

* fix docstring & slight bug

* make attn_implementation internal (_attn_implementation)

* typos

* fix tests

* deprecate use_flash_attention_2=True

* fix test

* add back llama that was removed by mistake

* fix tests

* remove _flash_attn_2_enabled occurences bis

* add check & test that passed attn_implementation is valid

* fix falcon torchscript export

* fix device of mask in tests

* add tip about torch.jit.trace and move bt doc below sdpa

* fix parameterized.expand order

* move tests from test_modeling_attn_mask_utils to test_modeling_utils as a relevant test class is already there

* update sdpaattention class with the new cache

* Update src/transformers/configuration_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bark/modeling_bark.py

* address review comments

* WIP torch.jit.trace fix. left: test both eager & sdpa

* add test for torch.jit.trace for both eager/sdpa

* fix falcon with torch==2.0 that needs to use sdpa

* fix doc

* hopefully last fix

* fix key_value_length that has no default now in mask converter

* is it flacky?

* fix speculative decoding bug

* tests do pass

* fix following #27907

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
fxmarty
2023-12-08 21:38:14 +01:00
committed by GitHub
parent ce0bbd5101
commit 80377eb018
54 changed files with 2227 additions and 454 deletions

View File

@@ -83,10 +83,10 @@ pip install -U flash-attn --no-build-isolation
##### Usage
To load a model using Flash Attention 2, we can pass the `use_flash_attention_2` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
To load a model using Flash Attention 2, we can pass the `attn_implementation="flash_attention_2"` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
```python
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
```
##### Performance comparison
@@ -114,7 +114,7 @@ import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
# load in fp16 and use Flash Attention 2
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
# enable CPU offload
model.enable_cpu_offload()

View File

@@ -153,7 +153,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> device = "cuda" # the device to load the model onto
>>> tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> text = "Replace me by any text you'd like."

View File

@@ -59,7 +59,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder")
>>> prompt = "def hello_world():"

View File

@@ -67,7 +67,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
>>> prompt = "def hello_world():"

View File

@@ -77,12 +77,12 @@ pip install -U flash-attn --no-build-isolation
### Usage
To load a model using Flash Attention 2, we can pass the `use_flash_attention_2` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
```python
>>> from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
...
```

View File

@@ -99,7 +99,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
>>> prompt = "My favourite condiment is"

View File

@@ -80,7 +80,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import OPTForCausalLM, GPT2Tokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
>>> prompt = ("A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "

View File

@@ -111,7 +111,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import PhiForCausalLM, AutoTokenizer
>>> # define the model and tokenizer and push the model and tokens to the GPU.
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, use_flash_attention_2=True).to("cuda")
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda")
>>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")
>>> # feel free to change the prompt to your liking.
@@ -163,4 +163,4 @@ Below is an expected speedup diagram that compares pure inference time between t
- forward
</pt>
</frameworkcontent>
</frameworkcontent>