Pipeline: simple API for assisted generation (#34504)

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
Joao Gante
2025-01-08 17:08:02 +00:00
committed by GitHub
parent 3f483beab9
commit 76da6ca034
14 changed files with 172 additions and 18 deletions

View File

@@ -1933,6 +1933,20 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
},
)
@require_torch
def test_pipeline_assisted_generation(self):
"""Tests that we can run assisted generation in the pipeline"""
model = "openai/whisper-tiny"
pipe = pipeline("automatic-speech-recognition", model=model, assistant_model=model)
# We can run the pipeline
prompt = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")["audio"]
_ = pipe(prompt)
# It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
with self.assertRaises(ValueError):
_ = pipe(prompt, generate_kwargs={"num_beams": 2})
def require_ffmpeg(test_case):
"""

View File

@@ -653,3 +653,17 @@ class TextGenerationPipelineTests(unittest.TestCase):
with CaptureLogger(logger) as cl:
_ = text_generator(prompt, max_length=10)
self.assertNotIn(logger_msg, cl.out)
@require_torch
def test_pipeline_assisted_generation(self):
"""Tests that we can run assisted generation in the pipeline"""
model = "hf-internal-testing/tiny-random-MistralForCausalLM"
pipe = pipeline("text-generation", model=model, assistant_model=model)
# We can run the pipeline
prompt = "Hello world"
_ = pipe(prompt)
# It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
with self.assertRaises(ValueError):
_ = pipe(prompt, generate_kwargs={"num_beams": 2})