Fix whisper for pipeline (#19482)
* update feature extractor params * update attention mask handling * fix doc and pipeline test * add warning when skipping test * add whisper translation and transcription test * fix build doc test
This commit is contained in:
@@ -218,6 +218,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
padding: Optional[str] = "max_length",
|
||||
max_length: Optional[int] = None,
|
||||
sampling_rate: Optional[int] = None,
|
||||
**kwargs
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
@@ -261,6 +262,19 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
The value that is used to fill the padding values / vectors.
|
||||
"""
|
||||
|
||||
if sampling_rate is not None:
|
||||
if sampling_rate != self.sampling_rate:
|
||||
raise ValueError(
|
||||
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
|
||||
f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
|
||||
f" {self.sampling_rate} and not {sampling_rate}."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"It is strongly recommended to pass the `sampling_rate` argument to this function. "
|
||||
"Failing to do so can result in silent errors that might be hard to debug."
|
||||
)
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(raw_speech, (list, tuple))
|
||||
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
||||
|
||||
@@ -31,13 +31,22 @@ from ...modeling_outputs import (
|
||||
Seq2SeqModelOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_whisper import WhisperConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "WhisperConfig"
|
||||
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
|
||||
_PROCESSOR_FOR_DOC = "openai/whisper-tiny"
|
||||
_EXPECTED_OUTPUT_SHAPE = [1, 2, 512]
|
||||
|
||||
|
||||
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
@@ -982,7 +991,14 @@ class WhisperModel(WhisperPreTrainedModel):
|
||||
return self.decoder
|
||||
|
||||
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_PROCESSOR_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=Seq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
||||
modality="audio",
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_features=None,
|
||||
@@ -999,26 +1015,6 @@ class WhisperModel(WhisperPreTrainedModel):
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import WhisperModel, WhisperFeatureExtractor
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> model = WhisperModel.from_pretrained("openai/whisper-base")
|
||||
>>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
|
||||
>>> input_features = inputs.input_features
|
||||
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
|
||||
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
|
||||
>>> list(last_hidden_state.shape)
|
||||
[1, 2, 512]
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
|
||||
@@ -26,6 +26,8 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
Speech2TextForConditionalGeneration,
|
||||
Wav2Vec2ForCTC,
|
||||
WhisperForConditionalGeneration,
|
||||
WhisperProcessor,
|
||||
)
|
||||
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
|
||||
from transformers.pipelines.audio_utils import chunk_bytes_iter
|
||||
@@ -308,6 +310,52 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
output = asr(data)
|
||||
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
def test_simple_whisper_asr(self):
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="openai/whisper-tiny.en",
|
||||
framework="pt",
|
||||
)
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
filename = ds[0]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to"})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
def test_simple_whisper_translation(self):
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="openai/whisper-large",
|
||||
framework="pt",
|
||||
)
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
|
||||
|
||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large")
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large")
|
||||
|
||||
speech_recognizer_2 = AutomaticSpeechRecognitionPipeline(
|
||||
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
|
||||
)
|
||||
output_2 = speech_recognizer_2(filename)
|
||||
self.assertEqual(output, output_2)
|
||||
|
||||
processor = WhisperProcessor(feature_extractor, tokenizer)
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")
|
||||
speech_translator = AutomaticSpeechRecognitionPipeline(
|
||||
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
|
||||
)
|
||||
output_3 = speech_translator(filename)
|
||||
self.assertEqual(output_3, {"text": " Un uomo ha detto allo universo, Sir, esiste."})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
|
||||
@@ -178,8 +178,16 @@ class ANY:
|
||||
class PipelineTestCaseMeta(type):
|
||||
def __new__(mcs, name, bases, dct):
|
||||
def gen_test(ModelClass, checkpoint, tiny_config, tokenizer_class, feature_extractor_class):
|
||||
@skipIf(tiny_config is None, "TinyConfig does not exist")
|
||||
@skipIf(checkpoint is None, "checkpoint does not exist")
|
||||
@skipIf(
|
||||
tiny_config is None,
|
||||
"TinyConfig does not exist, make sure that you defined a `_CONFIG_FOR_DOC` variable in the modeling"
|
||||
" file",
|
||||
)
|
||||
@skipIf(
|
||||
checkpoint is None,
|
||||
"checkpoint does not exist, make sure that you defined a `_CHECKPOINT_FOR_DOC` variable in the"
|
||||
" modeling file",
|
||||
)
|
||||
def test(self):
|
||||
if ModelClass.__name__.endswith("ForCausalLM"):
|
||||
tiny_config.is_encoder_decoder = False
|
||||
|
||||
Reference in New Issue
Block a user