@@ -332,10 +332,11 @@ In that case, you should see a warning message and we will fall back to the (slo
|
|||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:
|
By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.nn.attention.sdpa_kernel`](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) as a context manager:
|
||||||
|
|
||||||
```diff
|
```diff
|
||||||
import torch
|
import torch
|
||||||
|
+ from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||||
@@ -344,7 +345,7 @@ model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=to
|
|||||||
input_text = "Hello my dog is cute and"
|
input_text = "Hello my dog is cute and"
|
||||||
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
|
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||||
|
|
||||||
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
+ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
|
||||||
outputs = model.generate(**inputs)
|
outputs = model.generate(**inputs)
|
||||||
|
|
||||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||||
@@ -518,6 +519,7 @@ It is often possible to combine several of the optimization techniques described
|
|||||||
|
|
||||||
```py
|
```py
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
||||||
|
|
||||||
# load model in 4-bit
|
# load model in 4-bit
|
||||||
@@ -536,7 +538,7 @@ input_text = "Hello my dog is cute and"
|
|||||||
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
|
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||||
|
|
||||||
# enable FlashAttention
|
# enable FlashAttention
|
||||||
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
|
||||||
outputs = model.generate(**inputs)
|
outputs = model.generate(**inputs)
|
||||||
|
|
||||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||||
|
|||||||
Reference in New Issue
Block a user