Add decoder_kwargs to send to LM on asr pipeline. (#15646)
Co-authored-by: Giuseppe Attanasio <giuseppeattanasio6@gmail.com> Co-authored-by: Giuseppe Attanasio <giuseppeattanasio6@gmail.com>
This commit is contained in:
@@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Dict, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -180,7 +180,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
if "stride_length_s" in kwargs:
|
if "stride_length_s" in kwargs:
|
||||||
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
|
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
|
||||||
|
|
||||||
return preprocess_params, {}, {}
|
postprocess_params = {}
|
||||||
|
if "decoder_kwargs" in kwargs:
|
||||||
|
postprocess_params["decoder_kwargs"] = kwargs["decoder_kwargs"]
|
||||||
|
|
||||||
|
return preprocess_params, {}, postprocess_params
|
||||||
|
|
||||||
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
|
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
|
||||||
if isinstance(inputs, str):
|
if isinstance(inputs, str):
|
||||||
@@ -319,7 +323,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
extra = model_inputs
|
extra = model_inputs
|
||||||
return {"is_last": is_last, **out, **extra}
|
return {"is_last": is_last, **out, **extra}
|
||||||
|
|
||||||
def postprocess(self, model_outputs):
|
def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None):
|
||||||
if self.type == "ctc_with_lm":
|
if self.type == "ctc_with_lm":
|
||||||
final_logits = []
|
final_logits = []
|
||||||
for outputs in model_outputs:
|
for outputs in model_outputs:
|
||||||
@@ -334,9 +338,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
right_n = total_n - right
|
right_n = total_n - right
|
||||||
logits = logits[:, left:right_n]
|
logits = logits[:, left:right_n]
|
||||||
final_logits.append(logits)
|
final_logits.append(logits)
|
||||||
|
if decoder_kwargs is None:
|
||||||
|
decoder_kwargs = {}
|
||||||
logits = np.concatenate(final_logits, axis=1)
|
logits = np.concatenate(final_logits, axis=1)
|
||||||
logits = logits.squeeze(0)
|
logits = logits.squeeze(0)
|
||||||
text = self.decoder.decode_beams(logits)[0][0]
|
text = self.decoder.decode_beams(logits, **decoder_kwargs)[0][0]
|
||||||
else:
|
else:
|
||||||
skip_special_tokens = self.type != "ctc"
|
skip_special_tokens = self.type != "ctc"
|
||||||
tokens = np.concatenate([outputs["tokens"].numpy() for outputs in model_outputs], axis=-1)
|
tokens = np.concatenate([outputs["tokens"].numpy() for outputs in model_outputs], axis=-1)
|
||||||
|
|||||||
@@ -365,10 +365,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
audio_tiled = np.tile(audio, n_repeats)
|
audio_tiled = np.tile(audio, n_repeats)
|
||||||
|
|
||||||
output = speech_recognizer([audio_tiled], batch_size=2)
|
output = speech_recognizer([audio_tiled], batch_size=2)
|
||||||
|
|
||||||
self.assertEqual(output, [{"text": ANY(str)}])
|
self.assertEqual(output, [{"text": ANY(str)}])
|
||||||
self.assertEqual(output[0]["text"][:6], "<s> <s")
|
self.assertEqual(output[0]["text"][:6], "<s> <s")
|
||||||
|
|
||||||
|
# Making sure the argument are passed to the decoder
|
||||||
|
# Since no change happens in the result, check the error comes from
|
||||||
|
# the `decode_beams` function.
|
||||||
|
with self.assertRaises(TypeError) as e:
|
||||||
|
output = speech_recognizer([audio_tiled], decoder_kwargs={"num_beams": 2})
|
||||||
|
self.assertContains(e.msg, "TypeError: decode_beams() got an unexpected keyword argument 'num_beams'")
|
||||||
|
output = speech_recognizer([audio_tiled], decoder_kwargs={"beam_width": 2})
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_pyctcdecode
|
@require_pyctcdecode
|
||||||
def test_with_local_lm_fast(self):
|
def test_with_local_lm_fast(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user