[ASR pipeline] Correct ASR pipeline for seq2seq models (#15541)
This commit is contained in:
committed by
GitHub
parent
e02bdce791
commit
5f1918a4a8
@@ -265,10 +265,19 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
# it here.
|
# it here.
|
||||||
# Consume values so we can let extra information flow freely through
|
# Consume values so we can let extra information flow freely through
|
||||||
# the pipeline (important for `partial` in microphone)
|
# the pipeline (important for `partial` in microphone)
|
||||||
input_features = model_inputs.pop("input_features")
|
if "input_features" in model_inputs:
|
||||||
attention_mask = model_inputs.pop("attention_mask")
|
inputs = model_inputs.pop("input_features")
|
||||||
|
elif "input_values" in model_inputs:
|
||||||
|
inputs = model_inputs.pop("input_values")
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Seq2Seq speech recognition model requires either a "
|
||||||
|
f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attention_mask = model_inputs.pop("attention_mask", None)
|
||||||
tokens = self.model.generate(
|
tokens = self.model.generate(
|
||||||
encoder_outputs=encoder(input_features=input_features, attention_mask=attention_mask),
|
encoder_outputs=encoder(inputs, attention_mask=attention_mask),
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
out = {"tokens": tokens}
|
out = {"tokens": tokens}
|
||||||
|
|||||||
@@ -107,6 +107,24 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
output = speech_recognizer(waveform)
|
output = speech_recognizer(waveform)
|
||||||
self.assertEqual(output, {"text": "(Applaudissements)"})
|
self.assertEqual(output, {"text": "(Applaudissements)"})
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt_seq2seq(self):
|
||||||
|
model_id = "hf-internal-testing/tiny-random-speech-encoder-decoder"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
|
||||||
|
|
||||||
|
speech_recognizer = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model=model_id,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
feature_extractor=feature_extractor,
|
||||||
|
framework="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
||||||
|
output = speech_recognizer(waveform)
|
||||||
|
self.assertEqual(output, {"text": "あл ش 湯 清 ه ܬ া लᆨしث ल eか u w 全 u"})
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_pyctcdecode
|
@require_pyctcdecode
|
||||||
|
|||||||
Reference in New Issue
Block a user