[whisper] compile compatibility with long-form decoding (#31772)

* [whisper] compile compatibility with long-form decoding

* clarify comment

* fix after rebase

* finalise

* fix bsz

* fix cache split

* remove contiguous

* style

* finish

* update doc

* prevent cuda graph trace
Sanchit Gandhi
2024-08-01 18:10:56 +08:00
committed by GitHub
parent 9451a38526
commit e234061cdd
4 changed files with 156 additions and 15 deletions

@@ -72,7 +72,7 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```
-Whisper is compatible with the following optimisations:
+Whisper is compatible with the following optimisations for both short and long-form generation:
- [PyTorch Scaled Dot Product Attention (SDPA)](../perf_infer_gpu_one#pytorch-scaled-dot-product-attention): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`.
- [Flash Attention 2](../perf_infer_gpu_one#flashattention-2): improved implementation of flash attention through better parallelism and work partitioning.
- [torch.compile](../llm_optims#static-kv-cache-and-torchcompile): JIT-compile the forward pass to dispatch to efficient fused kernels.
@@ -101,7 +101,8 @@ As an example, the following codesnippet enables SDPA and `torch.compile` for up
... ).input_features
>>> # Compile the forward pass
->>> _ = model.generate(input_features)
+>>> for _ in range(2):
+>>>     model.generate(input_features)
>>> # Generate token ids using compiled graph (fast!)
>>> predicted_ids = model.generate(input_features)
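
The hunk above replaces a single warm-up call with two warm-up calls before the "fast" generation. The diff does not state why two passes are needed, so the following is only a minimal sketch of the pattern itself: run a callable a fixed number of times to absorb one-off costs, then time one further call. `warm_up_then_time` and `FakeCompiledModel` are hypothetical names introduced for illustration; the fake model is a stand-in that sleeps on its first calls and is not part of `transformers` or `torch`.

```python
import time
from typing import Any, Callable, List, Tuple


def warm_up_then_time(
    fn: Callable[[], Any], n_warmup: int = 2
) -> Tuple[List[float], float, Any]:
    """Call `fn` `n_warmup` times, then time one further call.

    Returns (warm-up durations, duration of the final call, final result).
    """
    warmup_times = []
    for _ in range(n_warmup):
        start = time.perf_counter()
        fn()  # result discarded, as in the warm-up loop in the diff
        warmup_times.append(time.perf_counter() - start)

    start = time.perf_counter()
    result = fn()
    steady_time = time.perf_counter() - start
    return warmup_times, steady_time, result


class FakeCompiledModel:
    """Hypothetical stand-in: the first `slow_calls` calls are slow.

    This only imitates a one-off start-up cost; it does not compile anything.
    """

    def __init__(self, slow_calls: int = 2, delay_s: float = 0.05) -> None:
        self.calls = 0
        self.slow_calls = slow_calls
        self.delay_s = delay_s

    def generate(self) -> List[int]:
        self.calls += 1
        if self.calls <= self.slow_calls:
            time.sleep(self.delay_s)  # simulated one-off cost
        return [self.calls]  # placeholder for predicted token ids
```

With a real model, one could pass `lambda: model.generate(input_features)` as `fn`, assuming `model` and `input_features` are set up as in the documentation snippet being patched.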