[whisper] compile compatibility with long-form decoding (#31772)
* [whisper] compile compatibility with long-form decoding
* clarify comment
* fix after rebase
* finalise
* fix bsz
* fix cache split
* remove contiguous
* style
* finish
* update doc
* prevent cuda graph trace
@@ -72,7 +72,7 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```
-Whisper is compatible with the following optimisations:
+Whisper is compatible with the following optimisations for both short and long-form generation:
- [PyTorch Scaled Dot Product Attention (SDPA)](../perf_infer_gpu_one#pytorch-scaled-dot-product-attention): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`.
- [Flash Attention 2](../perf_infer_gpu_one#flashattention-2): improved implementation of flash attention through better parallelism and work partitioning.
- [torch.compile](../llm_optims#static-kv-cache-and-torchcompile): JIT-compile the forward pass to dispatch to efficient fused kernels.
@@ -101,7 +101,8 @@ As an example, the following codesnippet enables SDPA and `torch.compile` for up
... ).input_features
>>> # Compile the forward pass
->>> _ = model.generate(input_features)
+>>> for _ in range(2):
+>>>     model.generate(input_features)
>>> # Generate token ids using compiled graph (fast!)
>>> predicted_ids = model.generate(input_features)
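
The second hunk replaces a single warm-up call with a loop of two warm-up calls before the real `generate` call. The diff does not say why two passes are needed, so the following is only a minimal, dependency-free sketch of the pattern itself. `warm_up` and `FakeCompiledModel` are hypothetical names introduced for illustration; they are not part of Transformers or PyTorch, and the stand-in model does no compilation.

```python
import time


def warm_up(generate_fn, inputs, num_passes=2):
    """Call generate_fn on inputs num_passes times and discard the outputs.

    Returns the wall-clock duration of each pass, so the caller can check
    when the per-call cost has settled.
    """
    durations = []
    for _ in range(num_passes):
        start = time.perf_counter()
        generate_fn(inputs)  # output intentionally ignored
        durations.append(time.perf_counter() - start)
    return durations


class FakeCompiledModel:
    """Hypothetical stand-in for a compiled model; it only counts calls."""

    def __init__(self):
        self.calls = 0

    def generate(self, inputs):
        self.calls += 1
        return [x * 2 for x in inputs]


model = FakeCompiledModel()

# Warm-up passes, mirroring `for _ in range(2): model.generate(...)` in the diff
durations = warm_up(model.generate, [1, 2, 3], num_passes=2)

# The call whose result is actually kept
predicted_ids = model.generate([1, 2, 3])
```

With a real compiled model, printing `durations` is one way to check whether the warm-up passes have absorbed the one-off cost before the timed call.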