Using assistant in AutomaticSpeechRecognitionPipeline with different encoder size (#30637)
* fiw input to generate in pipeline * fixup * pass input_features to generate with assistant * error if model and assistant with different enc size * fix * apply review suggestions * use self.config.is_encoder_decoder * pass inputs to generate directly * add slow tests * Update src/transformers/generation/utils.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update tests/pipelines/test_pipelines_automatic_speech_recognition.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update tests/pipelines/test_pipelines_automatic_speech_recognition.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update tests/pipelines/test_pipelines_automatic_speech_recognition.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update tests/pipelines/test_pipelines_automatic_speech_recognition.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update tests/pipelines/test_pipelines_automatic_speech_recognition.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * apply review * Update src/transformers/generation/utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/pipelines/test_pipelines_automatic_speech_recognition.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * apply code review * update attributes encoder_xyz to check * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * add slow test * solve conflicts --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -23,6 +24,8 @@ from transformers import (
|
||||
MODEL_FOR_CTC_MAPPING,
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
AutoFeatureExtractor,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
Speech2TextForConditionalGeneration,
|
||||
@@ -1138,6 +1141,94 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
{"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."},
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_speculative_decoding_whisper_non_distil(self):
|
||||
# Load data:
|
||||
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
|
||||
sample = dataset[0]["audio"]
|
||||
|
||||
# Load model:
|
||||
model_id = "openai/whisper-large-v2"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
model_id,
|
||||
use_safetensors=True,
|
||||
)
|
||||
|
||||
# Load assistant:
|
||||
assistant_model_id = "openai/whisper-tiny"
|
||||
assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
assistant_model_id,
|
||||
use_safetensors=True,
|
||||
)
|
||||
|
||||
# Load pipeline:
|
||||
pipe = AutomaticSpeechRecognitionPipeline(
|
||||
model=model,
|
||||
tokenizer=processor.tokenizer,
|
||||
feature_extractor=processor.feature_extractor,
|
||||
generate_kwargs={"language": "en"},
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
transcription_non_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
|
||||
total_time_assist = time.time() - start_time
|
||||
|
||||
start_time = time.time()
|
||||
transcription_ass = pipe(sample)["text"]
|
||||
total_time_non_assist = time.time() - start_time
|
||||
|
||||
self.assertEqual(transcription_ass, transcription_non_ass)
|
||||
self.assertEqual(
|
||||
transcription_ass,
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
|
||||
)
|
||||
self.assertTrue(total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster")
|
||||
|
||||
@slow
|
||||
def test_speculative_decoding_whisper_distil(self):
|
||||
# Load data:
|
||||
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
|
||||
sample = dataset[0]["audio"]
|
||||
|
||||
# Load model:
|
||||
model_id = "openai/whisper-large-v2"
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
model_id,
|
||||
use_safetensors=True,
|
||||
)
|
||||
|
||||
# Load assistant:
|
||||
assistant_model_id = "distil-whisper/distil-large-v2"
|
||||
assistant_model = AutoModelForCausalLM.from_pretrained(
|
||||
assistant_model_id,
|
||||
use_safetensors=True,
|
||||
)
|
||||
|
||||
# Load pipeline:
|
||||
pipe = AutomaticSpeechRecognitionPipeline(
|
||||
model=model,
|
||||
tokenizer=processor.tokenizer,
|
||||
feature_extractor=processor.feature_extractor,
|
||||
generate_kwargs={"language": "en"},
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
transcription_non_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
|
||||
total_time_assist = time.time() - start_time
|
||||
|
||||
start_time = time.time()
|
||||
transcription_ass = pipe(sample)["text"]
|
||||
total_time_non_assist = time.time() - start_time
|
||||
|
||||
self.assertEqual(transcription_ass, transcription_non_ass)
|
||||
self.assertEqual(
|
||||
transcription_ass,
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
|
||||
)
|
||||
self.assertEqual(total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster")
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
|
||||
Reference in New Issue
Block a user