Add generate kwargs to AutomaticSpeechRecognitionPipeline (#20952)
* Add generate kwargs to AutomaticSpeechRecognitionPipeline * Add test for generation kwargs
This commit is contained in:
@@ -24,6 +24,8 @@ from .base import ChunkPipeline
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from pyctcdecode import BeamSearchDecoderCTC
|
||||||
|
|
||||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -169,8 +171,14 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, feature_extractor: Union["SequenceFeatureExtractor", str], *args, **kwargs):
|
def __init__(
|
||||||
super().__init__(*args, **kwargs)
|
self,
|
||||||
|
feature_extractor: Union["SequenceFeatureExtractor", str],
|
||||||
|
*,
|
||||||
|
decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(**kwargs)
|
||||||
self.feature_extractor = feature_extractor
|
self.feature_extractor = feature_extractor
|
||||||
|
|
||||||
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
|
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
|
||||||
@@ -178,9 +186,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
elif (
|
elif (
|
||||||
feature_extractor._processor_class
|
feature_extractor._processor_class
|
||||||
and feature_extractor._processor_class.endswith("WithLM")
|
and feature_extractor._processor_class.endswith("WithLM")
|
||||||
and kwargs.get("decoder", None) is not None
|
and decoder is not None
|
||||||
):
|
):
|
||||||
self.decoder = kwargs["decoder"]
|
self.decoder = decoder
|
||||||
self.type = "ctc_with_lm"
|
self.type = "ctc_with_lm"
|
||||||
else:
|
else:
|
||||||
self.type = "ctc"
|
self.type = "ctc"
|
||||||
@@ -221,6 +229,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
`timestamps` along the text for every word in the text. For instance if you get `[{"text": "hi ",
|
`timestamps` along the text for every word in the text. For instance if you get `[{"text": "hi ",
|
||||||
"timestamps": (0.5,0.9), {"text": "there", "timestamps": (1.0, .1.5)}]`, then it means the model
|
"timestamps": (0.5,0.9), {"text": "there", "timestamps": (1.0, .1.5)}]`, then it means the model
|
||||||
predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
|
predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
|
||||||
|
generate_kwargs (`dict`, *optional*):
|
||||||
|
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
|
||||||
|
complete overview of generate, check the [following
|
||||||
|
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
|
||||||
|
max_new_tokens (`int`, *optional*):
|
||||||
|
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
`Dict`: A dictionary with the following keys:
|
`Dict`: A dictionary with the following keys:
|
||||||
@@ -233,23 +247,43 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
"""
|
"""
|
||||||
return super().__call__(inputs, **kwargs)
|
return super().__call__(inputs, **kwargs)
|
||||||
|
|
||||||
def _sanitize_parameters(self, **kwargs):
|
def _sanitize_parameters(
|
||||||
|
self,
|
||||||
|
chunk_length_s=None,
|
||||||
|
stride_length_s=None,
|
||||||
|
ignore_warning=None,
|
||||||
|
decoder_kwargs=None,
|
||||||
|
return_timestamps=None,
|
||||||
|
generate_kwargs=None,
|
||||||
|
max_new_tokens=None,
|
||||||
|
):
|
||||||
# No parameters on this pipeline right now
|
# No parameters on this pipeline right now
|
||||||
preprocess_params = {}
|
preprocess_params = {}
|
||||||
if "chunk_length_s" in kwargs:
|
if chunk_length_s is not None:
|
||||||
preprocess_params["chunk_length_s"] = kwargs["chunk_length_s"]
|
preprocess_params["chunk_length_s"] = chunk_length_s
|
||||||
if "stride_length_s" in kwargs:
|
if stride_length_s is not None:
|
||||||
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
|
preprocess_params["stride_length_s"] = stride_length_s
|
||||||
if "ignore_warning" in kwargs:
|
if ignore_warning is not None:
|
||||||
preprocess_params["ignore_warning"] = kwargs["ignore_warning"]
|
preprocess_params["ignore_warning"] = ignore_warning
|
||||||
|
|
||||||
|
forward_params = {"generate_kwargs": {}}
|
||||||
|
if max_new_tokens is not None:
|
||||||
|
forward_params["generate_kwargs"]["max_new_tokens"] = max_new_tokens
|
||||||
|
if generate_kwargs is not None:
|
||||||
|
if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
"`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
|
||||||
|
" only 1 version"
|
||||||
|
)
|
||||||
|
forward_params["generate_kwargs"].update(generate_kwargs)
|
||||||
|
|
||||||
postprocess_params = {}
|
postprocess_params = {}
|
||||||
if "decoder_kwargs" in kwargs:
|
if decoder_kwargs is not None:
|
||||||
postprocess_params["decoder_kwargs"] = kwargs["decoder_kwargs"]
|
postprocess_params["decoder_kwargs"] = decoder_kwargs
|
||||||
if "return_timestamps" in kwargs:
|
if return_timestamps is not None:
|
||||||
postprocess_params["return_timestamps"] = kwargs["return_timestamps"]
|
postprocess_params["return_timestamps"] = return_timestamps
|
||||||
|
|
||||||
return preprocess_params, {}, postprocess_params
|
return preprocess_params, forward_params, postprocess_params
|
||||||
|
|
||||||
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
|
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
|
||||||
if isinstance(inputs, str):
|
if isinstance(inputs, str):
|
||||||
@@ -351,7 +385,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
processed["stride"] = stride
|
processed["stride"] = stride
|
||||||
yield {"is_last": True, **processed, **extra}
|
yield {"is_last": True, **processed, **extra}
|
||||||
|
|
||||||
def _forward(self, model_inputs):
|
def _forward(self, model_inputs, generate_kwargs=None):
|
||||||
|
if generate_kwargs is None:
|
||||||
|
generate_kwargs = {}
|
||||||
|
|
||||||
is_last = model_inputs.pop("is_last")
|
is_last = model_inputs.pop("is_last")
|
||||||
if self.type == "seq2seq":
|
if self.type == "seq2seq":
|
||||||
encoder = self.model.get_encoder()
|
encoder = self.model.get_encoder()
|
||||||
@@ -376,6 +413,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
tokens = self.model.generate(
|
tokens = self.model.generate(
|
||||||
encoder_outputs=encoder(inputs, attention_mask=attention_mask),
|
encoder_outputs=encoder(inputs, attention_mask=attention_mask),
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
**generate_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
out = {"tokens": tokens}
|
out = {"tokens": tokens}
|
||||||
|
|||||||
@@ -169,6 +169,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
output = speech_recognizer(waveform)
|
output = speech_recognizer(waveform)
|
||||||
self.assertEqual(output, {"text": "あл ش 湯 清 ه ܬ া लᆨしث ल eか u w 全 u"})
|
self.assertEqual(output, {"text": "あл ش 湯 清 ه ܬ া लᆨしث ल eか u w 全 u"})
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_small_model_pt_seq2seq_gen_kwargs(self):
|
||||||
|
speech_recognizer = pipeline(
|
||||||
|
model="hf-internal-testing/tiny-random-speech-encoder-decoder",
|
||||||
|
framework="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
|
||||||
|
output = speech_recognizer(waveform, max_new_tokens=10, generate_kwargs={"num_beams": 2})
|
||||||
|
self.assertEqual(output, {"text": "あл † γ ت ב オ 束 泣 足"})
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_pyctcdecode
|
@require_pyctcdecode
|
||||||
|
|||||||
Reference in New Issue
Block a user