Update llm_optims docs for sdpa_kernel (#35481)
update: use sdpa_kernel
This commit is contained in:
@@ -156,9 +156,11 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
|
|||||||
There are a few important things you must do to enable static kv-cache and `torch.compile` with the `StaticCache` method:
|
There are a few important things you must do to enable static kv-cache and `torch.compile` with the `StaticCache` method:
|
||||||
1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length.
|
1. Initialize the [`StaticCache`] instance before using the model for inference. There you can configure parameters like the maximum batch size and sequence length.
|
||||||
2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
|
2. Call `torch.compile` on the model to compile the forward pass with the static kv-cache.
|
||||||
3. Set `enable_math=True` in the [torch.backends.cuda.sdp_kernel](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) context manager to enable the native PyTorch C++ implementation of scaled dot product attention to speed up inference even more.
|
3. Pass `SDPBackend.MATH` to the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to select the native PyTorch C++ implementation of scaled dot product attention, which speeds up inference even more.
|
||||||
|
|
||||||
```py
|
```py
|
||||||
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||||
|
|
||||||
batch_size, seq_length = inputs["input_ids"].shape
|
batch_size, seq_length = inputs["input_ids"].shape
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
past_key_values = StaticCache(
|
past_key_values = StaticCache(
|
||||||
@@ -179,7 +181,7 @@ with torch.no_grad():
|
|||||||
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
|
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
|
||||||
cache_position = torch.tensor([seq_length + 1], device=torch_device)
|
cache_position = torch.tensor([seq_length + 1], device=torch_device)
|
||||||
for _ in range(1, NUM_TOKENS_TO_GENERATE):
|
for _ in range(1, NUM_TOKENS_TO_GENERATE):
|
||||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
with sdpa_kernel(SDPBackend.MATH):
|
||||||
next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
|
next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
|
||||||
generated_ids[:, cache_position] = next_token.int()
|
generated_ids[:, cache_position] = next_token.int()
|
||||||
cache_position += 1
|
cache_position += 1
|
||||||
@@ -453,10 +455,11 @@ Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and
|
|||||||
> [!TIP]
|
> [!TIP]
|
||||||
> SDPA supports FlashAttention-2 as long as you have the latest PyTorch version installed.
|
> SDPA supports FlashAttention-2 as long as you have the latest PyTorch version installed.
|
||||||
|
|
||||||
Use the [torch.backends.cuda.sdp_kernel](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) context manager to explicitly enable or disable any of the three attention algorithms. For example, set `enable_flash=True` to enable FlashAttention.
|
Use the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to explicitly enable or disable any of the four attention algorithms. For example, pass `SDPBackend.FLASH_ATTENTION` to the context manager to enable FlashAttention.
|
||||||
|
|
||||||
```py
|
```py
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
@@ -464,7 +467,7 @@ model = AutoModelForCausalLM.from_pretrained(
|
|||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
)
|
)
|
||||||
|
|
||||||
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
|
||||||
outputs = model.generate(**inputs)
|
outputs = model.generate(**inputs)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user