[ASR pipeline] correct asr pipeline for seq2seq models (#15541)
This commit is contained in:
committed by
GitHub
parent
e02bdce791
commit
5f1918a4a8
@@ -265,10 +265,19 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
# it here.
|
||||
# Consume values so we can let extra information flow freely through
|
||||
# the pipeline (important for `partial` in microphone)
|
||||
input_features = model_inputs.pop("input_features")
|
||||
attention_mask = model_inputs.pop("attention_mask")
|
||||
if "input_features" in model_inputs:
|
||||
inputs = model_inputs.pop("input_features")
|
||||
elif "input_values" in model_inputs:
|
||||
inputs = model_inputs.pop("input_values")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Seq2Seq speech recognition model requires either a "
|
||||
f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
|
||||
)
|
||||
|
||||
attention_mask = model_inputs.pop("attention_mask", None)
|
||||
tokens = self.model.generate(
|
||||
encoder_outputs=encoder(input_features=input_features, attention_mask=attention_mask),
|
||||
encoder_outputs=encoder(inputs, attention_mask=attention_mask),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
out = {"tokens": tokens}
|
||||
|
||||
@@ -107,6 +107,24 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
output = speech_recognizer(waveform)
|
||||
self.assertEqual(output, {"text": "(Applaudissements)"})
|
||||
|
||||
@require_torch
|
||||
def test_small_model_pt_seq2seq(self):
|
||||
model_id = "hf-internal-testing/tiny-random-speech-encoder-decoder"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
|
||||
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model=model_id,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
||||
output = speech_recognizer(waveform)
|
||||
self.assertEqual(output, {"text": "あл ش 湯 清 ه ܬ া लᆨしث ल eか u w 全 u"})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_pyctcdecode
|
||||
|
||||
Reference in New Issue
Block a user