[whisper] compile compatibility with long-form decoding (#31772)

* [whisper] compile compatibility with long-form decoding

* clarify comment

* fix after rebase

* finalise

* fix bsz

* fix cache split

* remove contiguous

* style

* finish

* update doc

* prevent cuda graph trace
This commit is contained in:
Sanchit Gandhi
2024-08-01 18:10:56 +08:00
committed by GitHub
parent 9451a38526
commit e234061cdd
4 changed files with 156 additions and 15 deletions

View File

@@ -3386,6 +3386,66 @@ class WhisperModelIntegrationTests(unittest.TestCase):
# assert re-ordered generations match those from eager
assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()
@slow
def test_tiny_static_generation_long_form(self):
import torch._dynamo.config
# only permit 4 compilations: 2 prefill steps and 2 decoding steps (1 for each of conditioned/not conditioned)
torch._dynamo.config.cache_size_limit = 4
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.to(torch_device)
dataset = load_dataset("distil-whisper/meanwhile", "default")["test"]
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
input_speech = [audio["array"] for audio in dataset[2:4]["audio"]]
inputs = processor(
input_speech,
return_tensors="pt",
padding="longest",
truncation=False,
return_attention_mask=True,
sampling_rate=16_000,
)
inputs = inputs.to(torch_device)
gen_kwargs = {
"return_timestamps": True,
"no_speech_threshold": 0.6,
"temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
"compression_ratio_threshold": 1.35,
"condition_on_prev_tokens": True, # conditioning on prev tokens introduces a recompile on the second time step
"logprob_threshold": -1.0,
"num_beams": 1,
}
set_seed(42)
eager_generated_ids = model.generate(**inputs, **gen_kwargs)
# compile the forward pass and assert equivalence
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
set_seed(42)
static_generated_ids = model.generate(**inputs, **gen_kwargs)
assert (eager_generated_ids == static_generated_ids).all()
# check the compiled graph can be re-used and that the cache is correctly reset
# reverse the ordering of the input features
input_features = inputs.input_features
permutation_idx = (
torch.arange(input_features.shape[0], 0, step=-1, dtype=torch.long, device=input_features.device) - 1
)
input_features = input_features[permutation_idx, ...]
attention_mask = inputs.attention_mask[permutation_idx, ...]
set_seed(42)
static_generated_ids = model.generate(input_features, attention_mask=attention_mask, **gen_kwargs)
# assert re-ordered generations match those from eager
assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
if head_mask is None: