Pipeline: simple API for assisted generation (#34504)
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
@@ -1933,6 +1933,20 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_assisted_generation(self):
|
||||
"""Tests that we can run assisted generation in the pipeline"""
|
||||
model = "openai/whisper-tiny"
|
||||
pipe = pipeline("automatic-speech-recognition", model=model, assistant_model=model)
|
||||
|
||||
# We can run the pipeline
|
||||
prompt = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")["audio"]
|
||||
_ = pipe(prompt)
|
||||
|
||||
# It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
|
||||
with self.assertRaises(ValueError):
|
||||
_ = pipe(prompt, generate_kwargs={"num_beams": 2})
|
||||
|
||||
|
||||
def require_ffmpeg(test_case):
|
||||
"""
|
||||
|
||||
@@ -653,3 +653,17 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
with CaptureLogger(logger) as cl:
|
||||
_ = text_generator(prompt, max_length=10)
|
||||
self.assertNotIn(logger_msg, cl.out)
|
||||
|
||||
@require_torch
|
||||
def test_pipeline_assisted_generation(self):
|
||||
"""Tests that we can run assisted generation in the pipeline"""
|
||||
model = "hf-internal-testing/tiny-random-MistralForCausalLM"
|
||||
pipe = pipeline("text-generation", model=model, assistant_model=model)
|
||||
|
||||
# We can run the pipeline
|
||||
prompt = "Hello world"
|
||||
_ = pipe(prompt)
|
||||
|
||||
# It is running assisted generation under the hood (e.g. flags incompatible with assisted gen will crash)
|
||||
with self.assertRaises(ValueError):
|
||||
_ = pipe(prompt, generate_kwargs={"num_beams": 2})
|
||||
|
||||
Reference in New Issue
Block a user