Add generate kwargs to AutomaticSpeechRecognitionPipeline (#20952)

* Add generate kwargs to AutomaticSpeechRecognitionPipeline

* Add test for generation kwargs
This commit is contained in:
bofeng huang
2022-12-31 07:13:28 +01:00
committed by GitHub
parent 9e6da0a7ed
commit 47c9b22d08
2 changed files with 66 additions and 17 deletions

View File

@@ -169,6 +169,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output = speech_recognizer(waveform)
self.assertEqual(output, {"text": "あл ش 湯 清 ه ܬ া लᆨしث ल eか u w 全 u"})
@require_torch
def test_small_model_pt_seq2seq_gen_kwargs(self):
    """Run the tiny speech encoder-decoder pipeline with generation options.

    The pipeline is called with ``max_new_tokens`` and a ``generate_kwargs``
    dict, and the returned dict is compared against the pinned text.
    """
    asr = pipeline(
        framework="pt",
        model="hf-internal-testing/tiny-random-speech-encoder-decoder",
    )

    # Synthetic input: a 0..999 float32 ramp repeated 34 times.
    ramp = np.arange(1000, dtype=np.float32)
    audio = np.tile(ramp, 34)

    beam_options = {"num_beams": 2}
    result = asr(audio, max_new_tokens=10, generate_kwargs=beam_options)

    expected = {"text": "あл † γ ت ב オ 束 泣 足"}
    self.assertEqual(result, expected)
@slow
@require_torch
@require_pyctcdecode