[Whisper] Refactor whisper (#21252)

* update whisper logit processor

* add generate for whisper

* remove part of the whisper specific code from pipeline

* update logit processors

* major update

* enforce first timestamp

* update generate

* add more tests

* update new decoding strategy

* Apply suggestions from code review

* update docstring

* fixup

* default config will not have multilingual ar

* update expected tokenizer size, see pull on the hub for whisper-tiny
This commit is contained in:
Arthur
2023-01-25 13:09:43 +01:00
committed by GitHub
parent f83135eb76
commit 255257f3ea
6 changed files with 231 additions and 55 deletions

View File

@@ -31,8 +31,6 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
if is_torch_available():
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
from ..models.auto.modeling_auto import MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
@@ -413,13 +411,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if return_timestamps is not None:
forward_params["return_timestamps"] = return_timestamps
postprocess_params["return_timestamps"] = return_timestamps
if self.model.config.model_type == "whisper":
# Whisper is highly specific: if we want timestamps, we need to
# force Whisper to output timestamp tokens, which means we need
# to set this variable to prevent `no_timestamp_token` from being
# used in the decoder.
if "forced_decoder_ids" not in forward_params.get("generate_kwargs", {}):
forward_params["generate_kwargs"]["forced_decoder_ids"] = None
return preprocess_params, forward_params, postprocess_params
@@ -529,10 +520,11 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
def _forward(self, model_inputs, return_timestamps=False, generate_kwargs=None):
if generate_kwargs is None:
generate_kwargs = {}
if return_timestamps and self.type == "seq2seq_whisper":
generate_kwargs["return_timestamps"] = return_timestamps
is_last = model_inputs.pop("is_last")
if self.type == "seq2seq":
if self.type in {"seq2seq", "seq2seq_whisper"}:
encoder = self.model.get_encoder()
# Consume values so we can let extra information flow freely through
# the pipeline (important for `partial` in microphone)
@@ -557,16 +549,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
**generate_kwargs,
)
out = {"tokens": tokens}
elif self.type == "seq2seq_whisper":
stride = model_inputs.pop("stride", None)
tokens = self.model.generate(
input_features=model_inputs.pop("input_features"),
logits_processor=[WhisperTimeStampLogitsProcessor()] if return_timestamps else None,
**generate_kwargs,
)
out = {"tokens": tokens}
if stride is not None:
out["stride"] = stride
if self.type == "seq2seq_whisper":
stride = model_inputs.pop("stride", None)
if stride is not None:
out["stride"] = stride
else:
stride = model_inputs.pop("stride", None)