[ASR Pipeline] Fix init with timestamps (#25438)
* [ASR Pipeline] Fix init * refactor test * change default kwarg setting * only perform checks if we have to * override init * move pre/forward/post checks to sanitize
This commit is contained in:
@@ -17,19 +17,24 @@ from typing import TYPE_CHECKING, Dict, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from ..modelcard import ModelCard
|
||||||
|
from ..tokenization_utils import PreTrainedTokenizer
|
||||||
from ..utils import is_torch_available, is_torchaudio_available, logging
|
from ..utils import is_torch_available, is_torchaudio_available, logging
|
||||||
from .audio_utils import ffmpeg_read
|
from .audio_utils import ffmpeg_read
|
||||||
from .base import ChunkPipeline
|
from .base import ArgumentHandler, ChunkPipeline, infer_framework_load_model
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pyctcdecode import BeamSearchDecoderCTC
|
from pyctcdecode import BeamSearchDecoderCTC
|
||||||
|
|
||||||
from ..feature_extraction_sequence_utils import SequenceFeatureExtractor
|
from ..feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||||
|
from ..modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
|
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
|
||||||
|
|
||||||
|
|
||||||
@@ -194,14 +199,78 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
feature_extractor: Union["SequenceFeatureExtractor", str],
|
model: "PreTrainedModel",
|
||||||
*,
|
feature_extractor: Union["SequenceFeatureExtractor", str] = None,
|
||||||
|
tokenizer: Optional[PreTrainedTokenizer] = None,
|
||||||
decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
|
decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
|
||||||
|
modelcard: Optional[ModelCard] = None,
|
||||||
|
framework: Optional[str] = None,
|
||||||
|
task: str = "",
|
||||||
|
args_parser: ArgumentHandler = None,
|
||||||
|
device: Union[int, "torch.device"] = None,
|
||||||
|
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
|
||||||
|
binary_output: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
if framework is None:
|
||||||
self.feature_extractor = feature_extractor
|
framework, model = infer_framework_load_model(model, config=model.config)
|
||||||
|
|
||||||
|
self.task = task
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.feature_extractor = feature_extractor
|
||||||
|
self.modelcard = modelcard
|
||||||
|
self.framework = framework
|
||||||
|
|
||||||
|
# `accelerate` device map
|
||||||
|
hf_device_map = getattr(self.model, "hf_device_map", None)
|
||||||
|
|
||||||
|
if hf_device_map is not None and device is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please "
|
||||||
|
"discard the `device` argument when creating your pipeline object."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.framework == "tf":
|
||||||
|
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
|
||||||
|
|
||||||
|
# We shouldn't call `model.to()` for models loaded with accelerate
|
||||||
|
if device is not None and not (isinstance(device, int) and device < 0):
|
||||||
|
self.model.to(device)
|
||||||
|
|
||||||
|
if device is None:
|
||||||
|
if hf_device_map is not None:
|
||||||
|
# Take the first device used by `accelerate`.
|
||||||
|
device = next(iter(hf_device_map.values()))
|
||||||
|
else:
|
||||||
|
device = -1
|
||||||
|
|
||||||
|
if is_torch_available() and self.framework == "pt":
|
||||||
|
if isinstance(device, torch.device):
|
||||||
|
self.device = device
|
||||||
|
elif isinstance(device, str):
|
||||||
|
self.device = torch.device(device)
|
||||||
|
elif device < 0:
|
||||||
|
self.device = torch.device("cpu")
|
||||||
|
else:
|
||||||
|
self.device = torch.device(f"cuda:{device}")
|
||||||
|
else:
|
||||||
|
self.device = device if device is not None else -1
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
self.binary_output = binary_output
|
||||||
|
|
||||||
|
# Update config and generation_config with task specific parameters
|
||||||
|
task_specific_params = self.model.config.task_specific_params
|
||||||
|
if task_specific_params is not None and task in task_specific_params:
|
||||||
|
self.model.config.update(task_specific_params.get(task))
|
||||||
|
if self.model.can_generate():
|
||||||
|
self.model.generation_config.update(**task_specific_params.get(task))
|
||||||
|
|
||||||
|
self.call_count = 0
|
||||||
|
self._batch_size = kwargs.pop("batch_size", None)
|
||||||
|
self._num_workers = kwargs.pop("num_workers", None)
|
||||||
|
|
||||||
|
# set the model type so we can check we have the right pre- and post-processing parameters
|
||||||
if self.model.config.model_type == "whisper":
|
if self.model.config.model_type == "whisper":
|
||||||
self.type = "seq2seq_whisper"
|
self.type = "seq2seq_whisper"
|
||||||
elif self.model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
|
elif self.model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
|
||||||
@@ -216,8 +285,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
else:
|
else:
|
||||||
self.type = "ctc"
|
self.type = "ctc"
|
||||||
|
|
||||||
if self.framework == "tf":
|
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
|
||||||
raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.")
|
|
||||||
|
|
||||||
mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.copy()
|
mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.copy()
|
||||||
mapping.update(MODEL_FOR_CTC_MAPPING_NAMES)
|
mapping.update(MODEL_FOR_CTC_MAPPING_NAMES)
|
||||||
@@ -301,11 +369,16 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
# No parameters on this pipeline right now
|
# No parameters on this pipeline right now
|
||||||
preprocess_params = {}
|
preprocess_params = {}
|
||||||
if chunk_length_s is not None:
|
if chunk_length_s is not None:
|
||||||
|
if self.type == "seq2seq" and not ignore_warning:
|
||||||
|
logger.warning(
|
||||||
|
"Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
|
||||||
|
" be entirely accurate and will have caveats. More information:"
|
||||||
|
" https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
|
||||||
|
" ignore_warning=True)"
|
||||||
|
)
|
||||||
preprocess_params["chunk_length_s"] = chunk_length_s
|
preprocess_params["chunk_length_s"] = chunk_length_s
|
||||||
if stride_length_s is not None:
|
if stride_length_s is not None:
|
||||||
preprocess_params["stride_length_s"] = stride_length_s
|
preprocess_params["stride_length_s"] = stride_length_s
|
||||||
if ignore_warning is not None:
|
|
||||||
preprocess_params["ignore_warning"] = ignore_warning
|
|
||||||
|
|
||||||
forward_params = defaultdict(dict)
|
forward_params = defaultdict(dict)
|
||||||
if max_new_tokens is not None:
|
if max_new_tokens is not None:
|
||||||
@@ -322,6 +395,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
if decoder_kwargs is not None:
|
if decoder_kwargs is not None:
|
||||||
postprocess_params["decoder_kwargs"] = decoder_kwargs
|
postprocess_params["decoder_kwargs"] = decoder_kwargs
|
||||||
if return_timestamps is not None:
|
if return_timestamps is not None:
|
||||||
|
# Check whether we have a valid setting for return_timestamps and throw an error before we perform a forward pass
|
||||||
if self.type == "seq2seq" and return_timestamps:
|
if self.type == "seq2seq" and return_timestamps:
|
||||||
raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
|
raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
|
||||||
if self.type == "ctc_with_lm" and return_timestamps != "word":
|
if self.type == "ctc_with_lm" and return_timestamps != "word":
|
||||||
@@ -339,11 +413,13 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
forward_params["return_timestamps"] = return_timestamps
|
forward_params["return_timestamps"] = return_timestamps
|
||||||
postprocess_params["return_timestamps"] = return_timestamps
|
postprocess_params["return_timestamps"] = return_timestamps
|
||||||
if return_language is not None:
|
if return_language is not None:
|
||||||
|
if self.type != "seq2seq_whisper":
|
||||||
|
raise ValueError("Only Whisper can return language for now.")
|
||||||
postprocess_params["return_language"] = return_language
|
postprocess_params["return_language"] = return_language
|
||||||
|
|
||||||
return preprocess_params, forward_params, postprocess_params
|
return preprocess_params, forward_params, postprocess_params
|
||||||
|
|
||||||
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
|
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
|
||||||
if isinstance(inputs, str):
|
if isinstance(inputs, str):
|
||||||
if inputs.startswith("http://") or inputs.startswith("https://"):
|
if inputs.startswith("http://") or inputs.startswith("https://"):
|
||||||
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
||||||
@@ -378,8 +454,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
extra = inputs
|
extra = inputs
|
||||||
inputs = _inputs
|
inputs = _inputs
|
||||||
if in_sampling_rate != self.feature_extractor.sampling_rate:
|
if in_sampling_rate != self.feature_extractor.sampling_rate:
|
||||||
import torch
|
|
||||||
|
|
||||||
if is_torchaudio_available():
|
if is_torchaudio_available():
|
||||||
from torchaudio import functional as F
|
from torchaudio import functional as F
|
||||||
else:
|
else:
|
||||||
@@ -409,14 +483,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
|
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
|
||||||
|
|
||||||
if chunk_length_s:
|
if chunk_length_s:
|
||||||
if self.type == "seq2seq" and not ignore_warning:
|
|
||||||
logger.warning(
|
|
||||||
"Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
|
|
||||||
" be entirely accurate and will have caveats. More information:"
|
|
||||||
" https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
|
|
||||||
" ignore_warning=True)"
|
|
||||||
)
|
|
||||||
self._preprocess_params["ignore_warning"] = True
|
|
||||||
if stride_length_s is None:
|
if stride_length_s is None:
|
||||||
stride_length_s = chunk_length_s / 6
|
stride_length_s = chunk_length_s / 6
|
||||||
|
|
||||||
@@ -456,6 +522,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
|
def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
|
||||||
if generate_kwargs is None:
|
if generate_kwargs is None:
|
||||||
generate_kwargs = {}
|
generate_kwargs = {}
|
||||||
|
|
||||||
if return_timestamps and self.type == "seq2seq_whisper":
|
if return_timestamps and self.type == "seq2seq_whisper":
|
||||||
generate_kwargs["return_timestamps"] = return_timestamps
|
generate_kwargs["return_timestamps"] = return_timestamps
|
||||||
if return_timestamps == "word":
|
if return_timestamps == "word":
|
||||||
@@ -525,9 +592,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
# Optional return types
|
# Optional return types
|
||||||
optional = {}
|
optional = {}
|
||||||
|
|
||||||
if return_language is not None and self.type != "seq2seq_whisper":
|
|
||||||
raise ValueError("Only whisper can return language for now.")
|
|
||||||
|
|
||||||
final_items = []
|
final_items = []
|
||||||
key = "logits" if self.type == "ctc_with_lm" else "tokens"
|
key = "logits" if self.type == "ctc_with_lm" else "tokens"
|
||||||
stride = None
|
stride = None
|
||||||
|
|||||||
@@ -343,6 +343,58 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_return_timestamps_in_init(self):
|
||||||
|
# segment-level timestamps are accepted
|
||||||
|
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("openai/whisper-tiny")
|
||||||
|
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
|
||||||
|
|
||||||
|
dummy_speech = np.ones(100)
|
||||||
|
|
||||||
|
pipe = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model=model,
|
||||||
|
feature_extractor=feature_extractor,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
chunk_length_s=8,
|
||||||
|
stride_length_s=1,
|
||||||
|
return_timestamps=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
_ = pipe(dummy_speech)
|
||||||
|
|
||||||
|
# word-level timestamps are accepted
|
||||||
|
pipe = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model=model,
|
||||||
|
feature_extractor=feature_extractor,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
chunk_length_s=8,
|
||||||
|
stride_length_s=1,
|
||||||
|
return_timestamps="word",
|
||||||
|
)
|
||||||
|
|
||||||
|
_ = pipe(dummy_speech)
|
||||||
|
|
||||||
|
# char-level timestamps are not accepted
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError,
|
||||||
|
"^Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
|
||||||
|
"Use `return_timestamps='word'` or `return_timestamps=True` respectively.$",
|
||||||
|
):
|
||||||
|
pipe = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model=model,
|
||||||
|
feature_extractor=feature_extractor,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
chunk_length_s=8,
|
||||||
|
stride_length_s=1,
|
||||||
|
return_timestamps="char",
|
||||||
|
)
|
||||||
|
|
||||||
|
_ = pipe(dummy_speech)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
def test_torch_whisper(self):
|
def test_torch_whisper(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user