Revert "add attention_mask and position_ids in assisted model" (#27523)
* Revert "add attention_mask and position_ids in assisted model (#26892)"
This reverts commit 184f60dcec.
* more debug
This commit is contained in:
committed by
GitHub
parent
4989e73e2f
commit
5603fad247
@@ -18,6 +18,7 @@ import copy
|
||||
import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -1736,6 +1737,102 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertTrue(prompt in text)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_speculative_decoding_distil(self):
|
||||
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
||||
model_id = "openai/whisper-large-v2"
|
||||
model = WhisperForConditionalGeneration.from_pretrained(
|
||||
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
|
||||
)
|
||||
model.to(torch_device)
|
||||
|
||||
processor = WhisperProcessor.from_pretrained(model_id)
|
||||
|
||||
assistant_model_id = "distil-whisper/distil-large-v2"
|
||||
assistant_model = WhisperForCausalLM.from_pretrained(
|
||||
assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
|
||||
)
|
||||
assistant_model.to(torch_device)
|
||||
|
||||
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
sample = dataset[0]["audio"]
|
||||
|
||||
input_features = processor(sample["array"], return_tensors="pt").input_features.to("cuda").to(torch.float16)
|
||||
|
||||
# warm up assisted decoding
|
||||
_ = model.generate(input_features, assistant_model=assistant_model)
|
||||
# warm up non-assisted decoding
|
||||
_ = model.generate(input_features)
|
||||
|
||||
# assisted decoding
|
||||
start_time = time.time()
|
||||
tokens = model.generate(input_features, assistant_model=assistant_model)
|
||||
total_time_assist = time.time() - start_time
|
||||
|
||||
transcription_ass = processor.batch_decode(tokens, skip_special_tokens=True)
|
||||
|
||||
# non-assisted decoding
|
||||
start_time = time.time()
|
||||
tokens = model.generate(input_features)
|
||||
total_time_non_assist = time.time() - start_time
|
||||
|
||||
transcription_non_ass = processor.batch_decode(tokens, skip_special_tokens=True)
|
||||
|
||||
assert transcription_ass == transcription_non_ass
|
||||
assert transcription_ass == [
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
|
||||
]
|
||||
assert total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster"
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_speculative_decoding_non_distil(self):
|
||||
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
||||
model_id = "openai/whisper-large-v2"
|
||||
model = WhisperForConditionalGeneration.from_pretrained(
|
||||
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
|
||||
)
|
||||
model.to(torch_device)
|
||||
|
||||
processor = WhisperProcessor.from_pretrained(model_id)
|
||||
|
||||
assistant_model_id = "openai/whisper-tiny"
|
||||
assistant_model = WhisperForConditionalGeneration.from_pretrained(
|
||||
assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
|
||||
)
|
||||
assistant_model.to(torch_device)
|
||||
|
||||
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
sample = dataset[0]["audio"]
|
||||
|
||||
input_features = processor(sample["array"], return_tensors="pt").input_features.to("cuda").to(torch.float16)
|
||||
|
||||
# warm up assisted decoding
|
||||
_ = model.generate(input_features, assistant_model=assistant_model)
|
||||
# warm up non-assisted decoding
|
||||
_ = model.generate(input_features)
|
||||
|
||||
# assisted decoding
|
||||
start_time = time.time()
|
||||
tokens = model.generate(input_features, assistant_model=assistant_model)
|
||||
total_time_assist = time.time() - start_time
|
||||
|
||||
transcription_ass = processor.batch_decode(tokens, skip_special_tokens=True)
|
||||
|
||||
# non-assisted decoding
|
||||
start_time = time.time()
|
||||
tokens = model.generate(input_features)
|
||||
total_time_non_assist = time.time() - start_time
|
||||
|
||||
transcription_non_ass = processor.batch_decode(tokens, skip_special_tokens=True)
|
||||
|
||||
assert transcription_ass == transcription_non_ass
|
||||
assert transcription_ass == [
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
|
||||
]
|
||||
assert total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster"
|
||||
|
||||
|
||||
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
|
||||
if head_mask is None:
|
||||
|
||||
Reference in New Issue
Block a user