Update docs for sdpa_kernel (#35410)

update: sdp_kernel -> sdpa_kernel
This commit is contained in:
Jacky Lee
2024-12-30 09:50:34 -08:00
committed by GitHub
parent 5cabc75b4b
commit b5f97977ed

View File

@@ -332,10 +332,11 @@ In that case, you should see a warning message and we will fall back to the (slo
</Tip>
By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:
By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.nn.attention.sdpa_kernel`](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) as a context manager:
```diff
import torch
+ from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
@@ -344,7 +345,7 @@ model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=to
input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
@@ -518,6 +519,7 @@ It is often possible to combine several of the optimization techniques described
```py
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# load model in 4-bit
@@ -536,7 +538,7 @@ input_text = "Hello my dog is cute and"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
# enable FlashAttention
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))