Fix whisper for pipeline (#19482)
* update feature extractor params
* update attention mask handling
* fix doc and pipeline test
* add warning when skipping test
* add whisper translation and transcription test
* fix build doc test
This commit is contained in:
@@ -218,6 +218,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
return_attention_mask: Optional[bool] = None,
|
return_attention_mask: Optional[bool] = None,
|
||||||
padding: Optional[str] = "max_length",
|
padding: Optional[str] = "max_length",
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
|
sampling_rate: Optional[int] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> BatchFeature:
|
) -> BatchFeature:
|
||||||
"""
|
"""
|
||||||
@@ -261,6 +262,19 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
The value that is used to fill the padding values / vectors.
|
The value that is used to fill the padding values / vectors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if sampling_rate is not None:
|
||||||
|
if sampling_rate != self.sampling_rate:
|
||||||
|
raise ValueError(
|
||||||
|
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
|
||||||
|
f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
|
||||||
|
f" {self.sampling_rate} and not {sampling_rate}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"It is strongly recommended to pass the `sampling_rate` argument to this function. "
|
||||||
|
"Failing to do so can result in silent errors that might be hard to debug."
|
||||||
|
)
|
||||||
|
|
||||||
is_batched = bool(
|
is_batched = bool(
|
||||||
isinstance(raw_speech, (list, tuple))
|
isinstance(raw_speech, (list, tuple))
|
||||||
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
||||||
|
|||||||
@@ -31,13 +31,22 @@ from ...modeling_outputs import (
|
|||||||
Seq2SeqModelOutput,
|
Seq2SeqModelOutput,
|
||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from ...utils import (
|
||||||
|
add_code_sample_docstrings,
|
||||||
|
add_start_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
logging,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
from .configuration_whisper import WhisperConfig
|
from .configuration_whisper import WhisperConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "WhisperConfig"
|
_CONFIG_FOR_DOC = "WhisperConfig"
|
||||||
|
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
|
||||||
|
_PROCESSOR_FOR_DOC = "openai/whisper-tiny"
|
||||||
|
_EXPECTED_OUTPUT_SHAPE = [1, 2, 512]
|
||||||
|
|
||||||
|
|
||||||
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
@@ -982,7 +991,14 @@ class WhisperModel(WhisperPreTrainedModel):
|
|||||||
return self.decoder
|
return self.decoder
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
@add_code_sample_docstrings(
|
||||||
|
processor_class=_PROCESSOR_FOR_DOC,
|
||||||
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
output_type=Seq2SeqModelOutput,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
||||||
|
modality="audio",
|
||||||
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_features=None,
|
input_features=None,
|
||||||
@@ -999,26 +1015,6 @@ class WhisperModel(WhisperPreTrainedModel):
|
|||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
):
|
):
|
||||||
r"""
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> import torch
|
|
||||||
>>> from transformers import WhisperModel, WhisperFeatureExtractor
|
|
||||||
>>> from datasets import load_dataset
|
|
||||||
|
|
||||||
>>> model = WhisperModel.from_pretrained("openai/whisper-base")
|
|
||||||
>>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
|
|
||||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
|
||||||
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
|
|
||||||
>>> input_features = inputs.input_features
|
|
||||||
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
|
|
||||||
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
|
|
||||||
>>> list(last_hidden_state.shape)
|
|
||||||
[1, 2, 512]
|
|
||||||
```"""
|
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
Speech2TextForConditionalGeneration,
|
Speech2TextForConditionalGeneration,
|
||||||
Wav2Vec2ForCTC,
|
Wav2Vec2ForCTC,
|
||||||
|
WhisperForConditionalGeneration,
|
||||||
|
WhisperProcessor,
|
||||||
)
|
)
|
||||||
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
|
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
|
||||||
from transformers.pipelines.audio_utils import chunk_bytes_iter
|
from transformers.pipelines.audio_utils import chunk_bytes_iter
|
||||||
@@ -308,6 +310,52 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
output = asr(data)
|
output = asr(data)
|
||||||
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})
|
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
@require_torchaudio
|
||||||
|
def test_simple_whisper_asr(self):
|
||||||
|
speech_recognizer = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model="openai/whisper-tiny.en",
|
||||||
|
framework="pt",
|
||||||
|
)
|
||||||
|
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||||
|
filename = ds[0]["file"]
|
||||||
|
output = speech_recognizer(filename)
|
||||||
|
self.assertEqual(output, {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to"})
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
@require_torchaudio
|
||||||
|
def test_simple_whisper_translation(self):
|
||||||
|
speech_recognizer = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model="openai/whisper-large",
|
||||||
|
framework="pt",
|
||||||
|
)
|
||||||
|
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||||
|
filename = ds[40]["file"]
|
||||||
|
output = speech_recognizer(filename)
|
||||||
|
self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
|
||||||
|
|
||||||
|
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large")
|
||||||
|
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large")
|
||||||
|
|
||||||
|
speech_recognizer_2 = AutomaticSpeechRecognitionPipeline(
|
||||||
|
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
|
||||||
|
)
|
||||||
|
output_2 = speech_recognizer_2(filename)
|
||||||
|
self.assertEqual(output, output_2)
|
||||||
|
|
||||||
|
processor = WhisperProcessor(feature_extractor, tokenizer)
|
||||||
|
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")
|
||||||
|
speech_translator = AutomaticSpeechRecognitionPipeline(
|
||||||
|
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
|
||||||
|
)
|
||||||
|
output_3 = speech_translator(filename)
|
||||||
|
self.assertEqual(output_3, {"text": " Un uomo ha detto allo universo, Sir, esiste."})
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
|
|||||||
@@ -178,8 +178,16 @@ class ANY:
|
|||||||
class PipelineTestCaseMeta(type):
|
class PipelineTestCaseMeta(type):
|
||||||
def __new__(mcs, name, bases, dct):
|
def __new__(mcs, name, bases, dct):
|
||||||
def gen_test(ModelClass, checkpoint, tiny_config, tokenizer_class, feature_extractor_class):
|
def gen_test(ModelClass, checkpoint, tiny_config, tokenizer_class, feature_extractor_class):
|
||||||
@skipIf(tiny_config is None, "TinyConfig does not exist")
|
@skipIf(
|
||||||
@skipIf(checkpoint is None, "checkpoint does not exist")
|
tiny_config is None,
|
||||||
|
"TinyConfig does not exist, make sure that you defined a `_CONFIG_FOR_DOC` variable in the modeling"
|
||||||
|
" file",
|
||||||
|
)
|
||||||
|
@skipIf(
|
||||||
|
checkpoint is None,
|
||||||
|
"checkpoint does not exist, make sure that you defined a `_CHECKPOINT_FOR_DOC` variable in the"
|
||||||
|
" modeling file",
|
||||||
|
)
|
||||||
def test(self):
|
def test(self):
|
||||||
if ModelClass.__name__.endswith("ForCausalLM"):
|
if ModelClass.__name__.endswith("ForCausalLM"):
|
||||||
tiny_config.is_encoder_decoder = False
|
tiny_config.is_encoder_decoder = False
|
||||||
|
|||||||
Reference in New Issue
Block a user