[Whisper] Refactor whisper (#21252)
* update whisper logit processor
* add generate for whisper
* remove part of the whisper specific code from pipeline
* update logit processes
* major update
* enforce first timestamp
* update generate
* add more tests
* update new decoding strategy
* Apply suggestions from code review
* update docstring
* fixup
* default config will not have multilingual ar
* update expected tokenizer size, see pull on the hub for whisper-tiny
This commit is contained in:
@@ -31,8 +31,6 @@ if TYPE_CHECKING:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
|
||||
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
|
||||
|
||||
|
||||
@@ -413,13 +411,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
if return_timestamps is not None:
|
||||
forward_params["return_timestamps"] = return_timestamps
|
||||
postprocess_params["return_timestamps"] = return_timestamps
|
||||
if self.model.config.model_type == "whisper":
|
||||
# Whisper is highly specific, if we want timestamps, we need to
|
||||
# force whisper to output timestamp tokens, which means we need
|
||||
# to set this variable to prevent `no_timestamp_token` to be
|
||||
# used in the decoder.
|
||||
if "forced_decoder_ids" not in forward_params.get("generate_kwargs", {}):
|
||||
forward_params["generate_kwargs"]["forced_decoder_ids"] = None
|
||||
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
|
||||
@@ -529,10 +520,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
|
||||
if generate_kwargs is None:
|
||||
generate_kwargs = {}
|
||||
|
||||
if return_timestamps and self.type == "seq2seq_whisper":
|
||||
generate_kwargs["return_timestamps"] = return_timestamps
|
||||
is_last = model_inputs.pop("is_last")
|
||||
|
||||
if self.type == "seq2seq":
|
||||
if self.type in {"seq2seq", "seq2seq_whisper"}:
|
||||
encoder = self.model.get_encoder()
|
||||
# Consume values so we can let extra information flow freely through
|
||||
# the pipeline (important for `partial` in microphone)
|
||||
@@ -557,16 +549,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
**generate_kwargs,
|
||||
)
|
||||
out = {"tokens": tokens}
|
||||
elif self.type == "seq2seq_whisper":
|
||||
stride = model_inputs.pop("stride", None)
|
||||
tokens = self.model.generate(
|
||||
input_features=model_inputs.pop("input_features"),
|
||||
logits_processor=[WhisperTimeStampLogitsProcessor()] if return_timestamps else None,
|
||||
**generate_kwargs,
|
||||
)
|
||||
out = {"tokens": tokens}
|
||||
if stride is not None:
|
||||
out["stride"] = stride
|
||||
if self.type == "seq2seq_whisper":
|
||||
stride = model_inputs.pop("stride", None)
|
||||
if stride is not None:
|
||||
out["stride"] = stride
|
||||
|
||||
else:
|
||||
stride = model_inputs.pop("stride", None)
|
||||
|
||||
Reference in New Issue
Block a user