[whisper] compile compatibility with long-form decoding (#31772)
* [whisper] compile compatibility with long-form decoding * clarify comment * fix after rebase * finalise * fix bsz * fix cache split * remove contiguous * style * finish * update doc * prevent cuda graph trace
This commit is contained in:
@@ -3386,6 +3386,66 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
# assert re-ordered generations match those from eager
|
||||
assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()
|
||||
|
||||
@slow
|
||||
def test_tiny_static_generation_long_form(self):
|
||||
import torch._dynamo.config
|
||||
|
||||
# only permit 4 compilations: 2 prefill steps and 2 decoding steps (1 for each of conditioned/not conditioned)
|
||||
torch._dynamo.config.cache_size_limit = 4
|
||||
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||
model.to(torch_device)
|
||||
|
||||
dataset = load_dataset("distil-whisper/meanwhile", "default")["test"]
|
||||
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
|
||||
input_speech = [audio["array"] for audio in dataset[2:4]["audio"]]
|
||||
|
||||
inputs = processor(
|
||||
input_speech,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
truncation=False,
|
||||
return_attention_mask=True,
|
||||
sampling_rate=16_000,
|
||||
)
|
||||
inputs = inputs.to(torch_device)
|
||||
|
||||
gen_kwargs = {
|
||||
"return_timestamps": True,
|
||||
"no_speech_threshold": 0.6,
|
||||
"temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
"compression_ratio_threshold": 1.35,
|
||||
"condition_on_prev_tokens": True, # conditioning on prev tokens introduces a recompile on the second time step
|
||||
"logprob_threshold": -1.0,
|
||||
"num_beams": 1,
|
||||
}
|
||||
|
||||
set_seed(42)
|
||||
eager_generated_ids = model.generate(**inputs, **gen_kwargs)
|
||||
|
||||
# compile the forward pass and assert equivalence
|
||||
model.generation_config.cache_implementation = "static"
|
||||
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
set_seed(42)
|
||||
static_generated_ids = model.generate(**inputs, **gen_kwargs)
|
||||
assert (eager_generated_ids == static_generated_ids).all()
|
||||
|
||||
# check the compiled graph can be re-used and that the cache is correctly reset
|
||||
# reverse the ordering of the input features
|
||||
input_features = inputs.input_features
|
||||
permutation_idx = (
|
||||
torch.arange(input_features.shape[0], 0, step=-1, dtype=torch.long, device=input_features.device) - 1
|
||||
)
|
||||
input_features = input_features[permutation_idx, ...]
|
||||
attention_mask = inputs.attention_mask[permutation_idx, ...]
|
||||
|
||||
set_seed(42)
|
||||
static_generated_ids = model.generate(input_features, attention_mask=attention_mask, **gen_kwargs)
|
||||
# assert re-ordered generations match those from eager
|
||||
assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()
|
||||
|
||||
|
||||
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
|
||||
if head_mask is None:
|
||||
|
||||
Reference in New Issue
Block a user