Pipeline ASR with LM. (#15071)
* Pipeline ASR with LM. * Revamped into `self.decoder`. * Fixing. * 2nd fix. * Update src/transformers/pipelines/__init__.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Fixing. Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -611,6 +611,27 @@ def pipeline(
|
||||
feature_extractor, revision=revision, _from_pipeline=task, **model_kwargs
|
||||
)
|
||||
|
||||
if (
|
||||
feature_extractor._processor_class
|
||||
and feature_extractor._processor_class.endswith("WithLM")
|
||||
and isinstance(model_name, str)
|
||||
):
|
||||
try:
|
||||
from pyctcdecode import BeamSearchDecoderCTC
|
||||
|
||||
language_model_glob = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
|
||||
alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME
|
||||
allow_regex = [language_model_glob, alphabet_filename]
|
||||
|
||||
decoder = BeamSearchDecoderCTC.load_from_hf_hub(
|
||||
pretrained_model_name_or_path, allow_regex=allow_regex
|
||||
)
|
||||
kwargs["decoder"] = decoder
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}"
|
||||
)
|
||||
|
||||
if task == "translation" and model.config.task_specific_params:
|
||||
for key in model.config.task_specific_params:
|
||||
if key.startswith("translation"):
|
||||
|
||||
@@ -144,7 +144,18 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
|
||||
|
||||
self.check_model_type(dict(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items()))
|
||||
self.is_ctc = self.model.__class__ in MODEL_FOR_CTC_MAPPING.values()
|
||||
|
||||
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
|
||||
self.type = "seq2seq"
|
||||
elif (
|
||||
self.feature_extractor._processor_class
|
||||
and self.feature_extractor._processor_class.endswith("WithLM")
|
||||
and kwargs.get("decoder", None) is not None
|
||||
):
|
||||
self.decoder = kwargs["decoder"]
|
||||
self.type = "ctc_with_lm"
|
||||
else:
|
||||
self.type = "ctc"
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -222,8 +233,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
is_last = model_inputs.pop("is_last")
|
||||
model_class = self.model.__class__
|
||||
if model_class in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
|
||||
if self.type == "seq2seq":
|
||||
encoder = self.model.get_encoder()
|
||||
# we need to pass `processed.get("attention_mask")` here since audio encoder
|
||||
# attention mask length is different from expected text decoder `encoder_attention_mask` length
|
||||
@@ -232,7 +242,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
tokens = self.model.generate(
|
||||
encoder_outputs=encoder(**model_inputs), attention_mask=model_inputs.get("attention_mask")
|
||||
)
|
||||
elif model_class in MODEL_FOR_CTC_MAPPING.values():
|
||||
out = {"tokens": tokens}
|
||||
elif self.type == "ctc_with_lm":
|
||||
outputs = self.model(**model_inputs)
|
||||
out = {"logits": outputs.logits}
|
||||
|
||||
elif self.type == "ctc":
|
||||
stride = model_inputs.pop("stride", None)
|
||||
outputs = self.model(**model_inputs)
|
||||
tokens = outputs.logits.argmax(dim=-1)
|
||||
@@ -241,16 +256,22 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
stride = [stride]
|
||||
|
||||
apply_stride(tokens, stride)
|
||||
out = {"tokens": tokens}
|
||||
else:
|
||||
logger.warning("This is an unknown class, treating it as CTC.")
|
||||
outputs = self.model(**model_inputs)
|
||||
tokens = outputs.logits.argmax(dim=-1)
|
||||
return {"tokens": tokens, "is_last": is_last}
|
||||
out = {"tokens": tokens}
|
||||
return {"is_last": is_last, **out}
|
||||
|
||||
def postprocess(self, model_outputs):
|
||||
skip_special_tokens = False if "CTC" in self.tokenizer.__class__.__name__ else True
|
||||
if self.type == "ctc_with_lm":
|
||||
logits = np.concatenate([outputs["logits"].numpy() for outputs in model_outputs], axis=1)
|
||||
logits = logits.squeeze(0)
|
||||
text = self.decoder.decode_beams(logits)[0][0]
|
||||
else:
|
||||
skip_special_tokens = self.type != "ctc"
|
||||
tokens = np.concatenate([outputs["tokens"].numpy() for outputs in model_outputs], axis=-1)
|
||||
tokens = tokens.squeeze(0)
|
||||
|
||||
recognized_string = self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
||||
return {"text": recognized_string}
|
||||
text = self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
||||
return {"text": text}
|
||||
|
||||
@@ -32,6 +32,7 @@ from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
is_torch_available,
|
||||
nested_simplify,
|
||||
require_pyctcdecode,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torchaudio,
|
||||
@@ -97,6 +98,37 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
output = speech_recognizer(waveform)
|
||||
self.assertEqual(output, {"text": "(Applaudissements)"})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_pyctcdecode
|
||||
def test_large_model_pt_with_lm(self):
|
||||
dataset = load_dataset("Narsil/asr_dummy")
|
||||
filename = dataset["test"][3]["file"]
|
||||
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm",
|
||||
framework="pt",
|
||||
)
|
||||
self.assertEqual(speech_recognizer.type, "ctc_with_lm")
|
||||
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(
|
||||
output,
|
||||
{"text": "y en las ramas medio sumergidas revoloteaban algunos pájaros de quimérico y legendario plumaje"},
|
||||
)
|
||||
|
||||
# Override back to pure CTC
|
||||
speech_recognizer.type = "ctc"
|
||||
output = speech_recognizer(filename)
|
||||
# plumajre != plumaje
|
||||
self.assertEqual(
|
||||
output,
|
||||
{
|
||||
"text": "y en las ramas medio sumergidas revoloteaban algunos pájaros de quimérico y legendario plumajre"
|
||||
},
|
||||
)
|
||||
|
||||
@require_tf
|
||||
def test_small_model_tf(self):
|
||||
self.skipTest("Tensorflow not supported yet.")
|
||||
|
||||
Reference in New Issue
Block a user