Using assistant in AutomaticSpeechRecognitionPipeline with different encoder size (#30637)

* fiw input to generate in pipeline

* fixup

* pass input_features to generate with assistant

* error if model and assistant with different enc size

* fix

* apply review suggestions

* use self.config.is_encoder_decoder

* pass inputs to generate directly

* add slow tests

* Update src/transformers/generation/utils.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* apply review

* Update src/transformers/generation/utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/pipelines/test_pipelines_automatic_speech_recognition.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* apply code review

* update attributes encoder_xyz to check

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* add slow test

* solve conflicts

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Kamil Akesbi
2024-05-23 10:59:38 +02:00
committed by GitHub
parent 15585b81a5
commit eb1a77bbb0
4 changed files with 174 additions and 7 deletions

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import unittest
import numpy as np
@@ -23,6 +24,8 @@ from transformers import (
MODEL_FOR_CTC_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
AutoFeatureExtractor,
AutoModelForCausalLM,
AutoModelForSpeechSeq2Seq,
AutoProcessor,
AutoTokenizer,
Speech2TextForConditionalGeneration,
@@ -1138,6 +1141,94 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
{"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."},
)
@slow
def test_speculative_decoding_whisper_non_distil(self):
# Load data:
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
sample = dataset[0]["audio"]
# Load model:
model_id = "openai/whisper-large-v2"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
use_safetensors=True,
)
# Load assistant:
assistant_model_id = "openai/whisper-tiny"
assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
assistant_model_id,
use_safetensors=True,
)
# Load pipeline:
pipe = AutomaticSpeechRecognitionPipeline(
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
generate_kwargs={"language": "en"},
)
start_time = time.time()
transcription_non_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
total_time_assist = time.time() - start_time
start_time = time.time()
transcription_ass = pipe(sample)["text"]
total_time_non_assist = time.time() - start_time
self.assertEqual(transcription_ass, transcription_non_ass)
self.assertEqual(
transcription_ass,
" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
)
self.assertTrue(total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster")
@slow
def test_speculative_decoding_whisper_distil(self):
# Load data:
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
sample = dataset[0]["audio"]
# Load model:
model_id = "openai/whisper-large-v2"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
use_safetensors=True,
)
# Load assistant:
assistant_model_id = "distil-whisper/distil-large-v2"
assistant_model = AutoModelForCausalLM.from_pretrained(
assistant_model_id,
use_safetensors=True,
)
# Load pipeline:
pipe = AutomaticSpeechRecognitionPipeline(
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
generate_kwargs={"language": "en"},
)
start_time = time.time()
transcription_non_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
total_time_assist = time.time() - start_time
start_time = time.time()
transcription_ass = pipe(sample)["text"]
total_time_non_assist = time.time() - start_time
self.assertEqual(transcription_ass, transcription_non_ass)
self.assertEqual(
transcription_ass,
" Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
)
self.assertEqual(total_time_non_assist > total_time_assist, "Make sure that assistant decoding is faster")
@slow
@require_torch
@require_torchaudio