Pipeline ASR with LM. (#15071)
* Pipeline ASR with LM.
* Revamped into `self.decoder`.
* Fixing.
* 2nd fix.
* Update src/transformers/pipelines/__init__.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Fixing.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -611,6 +611,27 @@ def pipeline(
|
|||||||
feature_extractor, revision=revision, _from_pipeline=task, **model_kwargs
|
feature_extractor, revision=revision, _from_pipeline=task, **model_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
feature_extractor._processor_class
|
||||||
|
and feature_extractor._processor_class.endswith("WithLM")
|
||||||
|
and isinstance(model_name, str)
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
from pyctcdecode import BeamSearchDecoderCTC
|
||||||
|
|
||||||
|
language_model_glob = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
|
||||||
|
alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
|
||||||
|
allow_regex = [language_model_glob, alphabet_filename]
|
||||||
|
|
||||||
|
decoder = BeamSearchDecoderCTC.load_from_hf_hub(
|
||||||
|
pretrained_model_name_or_path, allow_regex=allow_regex
|
||||||
|
)
|
||||||
|
kwargs["decoder"] = decoder
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
if task == "translation" and model.config.task_specific_params:
|
if task == "translation" and model.config.task_specific_params:
|
||||||
for key in model.config.task_specific_params:
|
for key in model.config.task_specific_params:
|
||||||
if key.startswith("translation"):
|
if key.startswith("translation"):
|
||||||
|
|||||||
@@ -144,7 +144,18 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
|
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
|
||||||
|
|
||||||
self.check_model_type(dict(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items()))
|
self.check_model_type(dict(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items()))
|
||||||
self.is_ctc = self.model.__class__ in MODEL_FOR_CTC_MAPPING.values()
|
|
||||||
|
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
|
||||||
|
self.type = "seq2seq"
|
||||||
|
elif (
|
||||||
|
self.feature_extractor._processor_class
|
||||||
|
and self.feature_extractor._processor_class.endswith("WithLM")
|
||||||
|
and kwargs.get("decoder", None) is not None
|
||||||
|
):
|
||||||
|
self.decoder = kwargs["decoder"]
|
||||||
|
self.type = "ctc_with_lm"
|
||||||
|
else:
|
||||||
|
self.type = "ctc"
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -222,8 +233,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
|
|
||||||
def _forward(self, model_inputs):
|
def _forward(self, model_inputs):
|
||||||
is_last = model_inputs.pop("is_last")
|
is_last = model_inputs.pop("is_last")
|
||||||
model_class = self.model.__class__
|
if self.type == "seq2seq":
|
||||||
if model_class in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
|
|
||||||
encoder = self.model.get_encoder()
|
encoder = self.model.get_encoder()
|
||||||
# we need to pass `processed.get("attention_mask")` here since audio encoder
|
# we need to pass `processed.get("attention_mask")` here since audio encoder
|
||||||
# attention mask length is different from expected text decoder `encoder_attention_mask` length
|
# attention mask length is different from expected text decoder `encoder_attention_mask` length
|
||||||
@@ -232,7 +242,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
tokens = self.model.generate(
|
tokens = self.model.generate(
|
||||||
encoder_outputs=encoder(**model_inputs), attention_mask=model_inputs.get("attention_mask")
|
encoder_outputs=encoder(**model_inputs), attention_mask=model_inputs.get("attention_mask")
|
||||||
)
|
)
|
||||||
elif model_class in MODEL_FOR_CTC_MAPPING.values():
|
out = {"tokens": tokens}
|
||||||
|
elif self.type == "ctc_with_lm":
|
||||||
|
outputs = self.model(**model_inputs)
|
||||||
|
out = {"logits": outputs.logits}
|
||||||
|
|
||||||
|
elif self.type == "ctc":
|
||||||
stride = model_inputs.pop("stride", None)
|
stride = model_inputs.pop("stride", None)
|
||||||
outputs = self.model(**model_inputs)
|
outputs = self.model(**model_inputs)
|
||||||
tokens = outputs.logits.argmax(dim=-1)
|
tokens = outputs.logits.argmax(dim=-1)
|
||||||
@@ -241,16 +256,22 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
stride = [stride]
|
stride = [stride]
|
||||||
|
|
||||||
apply_stride(tokens, stride)
|
apply_stride(tokens, stride)
|
||||||
|
out = {"tokens": tokens}
|
||||||
else:
|
else:
|
||||||
logger.warning("This is an unknown class, treating it as CTC.")
|
logger.warning("This is an unknown class, treating it as CTC.")
|
||||||
outputs = self.model(**model_inputs)
|
outputs = self.model(**model_inputs)
|
||||||
tokens = outputs.logits.argmax(dim=-1)
|
tokens = outputs.logits.argmax(dim=-1)
|
||||||
return {"tokens": tokens, "is_last": is_last}
|
out = {"tokens": tokens}
|
||||||
|
return {"is_last": is_last, **out}
|
||||||
|
|
||||||
def postprocess(self, model_outputs):
|
def postprocess(self, model_outputs):
|
||||||
skip_special_tokens = False if "CTC" in self.tokenizer.__class__.__name__ else True
|
if self.type == "ctc_with_lm":
|
||||||
|
logits = np.concatenate([outputs["logits"].numpy() for outputs in model_outputs], axis=1)
|
||||||
|
logits = logits.squeeze(0)
|
||||||
|
text = self.decoder.decode_beams(logits)[0][0]
|
||||||
|
else:
|
||||||
|
skip_special_tokens = self.type != "ctc"
|
||||||
tokens = np.concatenate([outputs["tokens"].numpy() for outputs in model_outputs], axis=-1)
|
tokens = np.concatenate([outputs["tokens"].numpy() for outputs in model_outputs], axis=-1)
|
||||||
tokens = tokens.squeeze(0)
|
tokens = tokens.squeeze(0)
|
||||||
|
text = self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
||||||
recognized_string = self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
return {"text": text}
|
||||||
return {"text": recognized_string}
|
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from transformers.testing_utils import (
|
|||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
|
require_pyctcdecode,
|
||||||
require_tf,
|
require_tf,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torchaudio,
|
require_torchaudio,
|
||||||
@@ -97,6 +98,37 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
output = speech_recognizer(waveform)
|
output = speech_recognizer(waveform)
|
||||||
self.assertEqual(output, {"text": "(Applaudissements)"})
|
self.assertEqual(output, {"text": "(Applaudissements)"})
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
@require_pyctcdecode
|
||||||
|
def test_large_model_pt_with_lm(self):
|
||||||
|
dataset = load_dataset("Narsil/asr_dummy")
|
||||||
|
filename = dataset["test"][3]["file"]
|
||||||
|
|
||||||
|
speech_recognizer = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model="patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm",
|
||||||
|
framework="pt",
|
||||||
|
)
|
||||||
|
self.assertEqual(speech_recognizer.type, "ctc_with_lm")
|
||||||
|
|
||||||
|
output = speech_recognizer(filename)
|
||||||
|
self.assertEqual(
|
||||||
|
output,
|
||||||
|
{"text": "y en las ramas medio sumergidas revoloteaban algunos pájaros de quimérico y legendario plumaje"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Override back to pure CTC
|
||||||
|
speech_recognizer.type = "ctc"
|
||||||
|
output = speech_recognizer(filename)
|
||||||
|
# plumajre != plumaje
|
||||||
|
self.assertEqual(
|
||||||
|
output,
|
||||||
|
{
|
||||||
|
"text": "y en las ramas medio sumergidas revoloteaban algunos pájaros de quimérico y legendario plumajre"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_small_model_tf(self):
|
def test_small_model_tf(self):
|
||||||
self.skipTest("Tensorflow not supported yet.")
|
self.skipTest("Tensorflow not supported yet.")
|
||||||
|
|||||||
Reference in New Issue
Block a user