Fix Whisper Conversion Script: Correct decoder_attention_heads and _download function (#26834)
* Fix error in convert_openai_to_hf.py: "_download() missing 1 required positional argument: root" * Fix error in convert_openai_to_hf.py: "TypeError: byte indices must be integers or slices, not str" * Fix decoder_attention_heads value in convert_openai_to_hf.py. Correct the assignment for `decoder_attention_heads` in the conversion script for the Whisper model. * Black reformat convert_openai_to_hf.py file. * Fix Whisper model configuration defaults (for Tiny). - Correct encoder/decoder layers and attention heads count. - Update model width (`d_model`) to 384. * Add docstring to the convert_openai_to_hf.py script with a doctest * Add shebang and +x permission to the convert_openai_to_hf.py * convert_openai_to_hf.py: reuse the read model_bytes in the _download() function * Move convert_openai_to_hf.py doctest example to whisper.md * whisper.md: Add an inference example to the Conversion section. * whisper.md: remove `model.config.forced_decoder_ids` from examples (deprecated) * whisper.md: Remove "## Format Conversion" section; not used by users * whisper.md: Use librispeech_asr_dummy dataset and load_dataset()
This commit is contained in:
@@ -34,6 +34,42 @@ The original code can be found [here](https://github.com/openai/whisper).
|
|||||||
- Inference is currently only implemented for short-form i.e. audio is pre-segmented into <=30s segments. Long-form (including timestamps) will be implemented in a future release.
|
- Inference is currently only implemented for short-form i.e. audio is pre-segmented into <=30s segments. Long-form (including timestamps) will be implemented in a future release.
|
||||||
- One can use [`WhisperProcessor`] to prepare audio for the model, and decode the predicted ID's back into text.
|
- One can use [`WhisperProcessor`] to prepare audio for the model, and decode the predicted ID's back into text.
|
||||||
|
|
||||||
|
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts).
|
||||||
|
The original code can be found [here](https://github.com/openai/whisper).
|
||||||
|
|
||||||
|
## Inference
|
||||||
|
|
||||||
|
Here is a step-by-step guide to transcribing an audio sample using a pre-trained Whisper model:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from datasets import load_dataset
|
||||||
|
>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
||||||
|
|
||||||
|
>>> # Select an audio file and read it:
|
||||||
|
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||||
|
>>> audio_sample = ds[0]["audio"]
|
||||||
|
>>> waveform = audio_sample["array"]
|
||||||
|
>>> sampling_rate = audio_sample["sampling_rate"]
|
||||||
|
|
||||||
|
>>> # Load the Whisper model in Hugging Face format:
|
||||||
|
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||||
|
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||||
|
|
||||||
|
>>> # Use the model and processor to transcribe the audio:
|
||||||
|
>>> input_features = processor(
|
||||||
|
... waveform, sampling_rate=sampling_rate, return_tensors="pt"
|
||||||
|
... ).input_features
|
||||||
|
|
||||||
|
>>> # Generate token ids
|
||||||
|
>>> predicted_ids = model.generate(input_features)
|
||||||
|
|
||||||
|
>>> # Decode token ids to text
|
||||||
|
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
>>> transcription[0]
|
||||||
|
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
|
||||||
|
```
|
||||||
|
|
||||||
## WhisperConfig
|
## WhisperConfig
|
||||||
|
|
||||||
[[autodoc]] WhisperConfig
|
[[autodoc]] WhisperConfig
|
||||||
|
|||||||
@@ -77,13 +77,13 @@ class WhisperConfig(PretrainedConfig):
|
|||||||
num_mel_bins (`int`, *optional*, defaults to 80):
|
num_mel_bins (`int`, *optional*, defaults to 80):
|
||||||
Number of mel features used per input features. Should correspond to the value used in the
|
Number of mel features used per input features. Should correspond to the value used in the
|
||||||
`WhisperProcessor` class.
|
`WhisperProcessor` class.
|
||||||
encoder_layers (`int`, *optional*, defaults to 6):
|
encoder_layers (`int`, *optional*, defaults to 4):
|
||||||
Number of encoder layers.
|
Number of encoder layers.
|
||||||
decoder_layers (`int`, *optional*, defaults to 6):
|
decoder_layers (`int`, *optional*, defaults to 4):
|
||||||
Number of decoder layers.
|
Number of decoder layers.
|
||||||
encoder_attention_heads (`int`, *optional*, defaults to 4):
|
encoder_attention_heads (`int`, *optional*, defaults to 6):
|
||||||
Number of attention heads for each attention layer in the Transformer encoder.
|
Number of attention heads for each attention layer in the Transformer encoder.
|
||||||
decoder_attention_heads (`int`, *optional*, defaults to 4):
|
decoder_attention_heads (`int`, *optional*, defaults to 6):
|
||||||
Number of attention heads for each attention layer in the Transformer decoder.
|
Number of attention heads for each attention layer in the Transformer decoder.
|
||||||
encoder_ffn_dim (`int`, *optional*, defaults to 1536):
|
encoder_ffn_dim (`int`, *optional*, defaults to 1536):
|
||||||
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
|
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
|
||||||
@@ -106,7 +106,7 @@ class WhisperConfig(PretrainedConfig):
|
|||||||
activation_function (`str`, *optional*, defaults to `"gelu"`):
|
activation_function (`str`, *optional*, defaults to `"gelu"`):
|
||||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||||
d_model (`int`, *optional*, defaults to 256):
|
d_model (`int`, *optional*, defaults to 384):
|
||||||
Dimensionality of the layers.
|
Dimensionality of the layers.
|
||||||
dropout (`float`, *optional*, defaults to 0.1):
|
dropout (`float`, *optional*, defaults to 0.1):
|
||||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||||
@@ -197,10 +197,10 @@ class WhisperConfig(PretrainedConfig):
|
|||||||
self,
|
self,
|
||||||
vocab_size=51865,
|
vocab_size=51865,
|
||||||
num_mel_bins=80,
|
num_mel_bins=80,
|
||||||
encoder_layers=6,
|
encoder_layers=4,
|
||||||
encoder_attention_heads=4,
|
encoder_attention_heads=6,
|
||||||
decoder_layers=6,
|
decoder_layers=4,
|
||||||
decoder_attention_heads=4,
|
decoder_attention_heads=6,
|
||||||
decoder_ffn_dim=1536,
|
decoder_ffn_dim=1536,
|
||||||
encoder_ffn_dim=1536,
|
encoder_ffn_dim=1536,
|
||||||
encoder_layerdrop=0.0,
|
encoder_layerdrop=0.0,
|
||||||
@@ -209,7 +209,7 @@ class WhisperConfig(PretrainedConfig):
|
|||||||
use_cache=True,
|
use_cache=True,
|
||||||
is_encoder_decoder=True,
|
is_encoder_decoder=True,
|
||||||
activation_function="gelu",
|
activation_function="gelu",
|
||||||
d_model=256,
|
d_model=384,
|
||||||
dropout=0.0,
|
dropout=0.0,
|
||||||
attention_dropout=0.0,
|
attention_dropout=0.0,
|
||||||
activation_dropout=0.0,
|
activation_dropout=0.0,
|
||||||
|
|||||||
14
src/transformers/models/whisper/convert_openai_to_hf.py
Normal file → Executable file
14
src/transformers/models/whisper/convert_openai_to_hf.py
Normal file → Executable file
@@ -1,3 +1,5 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""Converts a Whisper model in OpenAI format to Hugging Face format."""
|
||||||
# Copyright 2022 The HuggingFace Inc. team and the OpenAI team. All rights reserved.
|
# Copyright 2022 The HuggingFace Inc. team and the OpenAI team. All rights reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@@ -14,6 +16,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import io
|
||||||
import os
|
import os
|
||||||
import urllib
|
import urllib
|
||||||
import warnings
|
import warnings
|
||||||
@@ -90,7 +93,7 @@ def make_linear_from_emb(emb):
|
|||||||
return lin_layer
|
return lin_layer
|
||||||
|
|
||||||
|
|
||||||
def _download(url: str, root: str) -> bytes:
|
def _download(url: str, root: str) -> io.BytesIO:
|
||||||
os.makedirs(root, exist_ok=True)
|
os.makedirs(root, exist_ok=True)
|
||||||
filename = os.path.basename(url)
|
filename = os.path.basename(url)
|
||||||
|
|
||||||
@@ -103,7 +106,7 @@ def _download(url: str, root: str) -> bytes:
|
|||||||
if os.path.isfile(download_target):
|
if os.path.isfile(download_target):
|
||||||
model_bytes = open(download_target, "rb").read()
|
model_bytes = open(download_target, "rb").read()
|
||||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||||
return model_bytes
|
return torch.load(io.BytesIO(model_bytes))
|
||||||
else:
|
else:
|
||||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||||
|
|
||||||
@@ -125,12 +128,13 @@ def _download(url: str, root: str) -> bytes:
|
|||||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_bytes
|
return torch.load(io.BytesIO(model_bytes))
|
||||||
|
|
||||||
|
|
||||||
def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
|
def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
|
||||||
if ".pt" not in checkpoint_path:
|
if ".pt" not in checkpoint_path:
|
||||||
original_checkpoint = _download(_MODELS[checkpoint_path])
|
root = os.path.dirname(pytorch_dump_folder_path) or "."
|
||||||
|
original_checkpoint = _download(_MODELS[checkpoint_path], root)
|
||||||
else:
|
else:
|
||||||
original_checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
original_checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||||
dimensions = original_checkpoint["dims"]
|
dimensions = original_checkpoint["dims"]
|
||||||
@@ -151,7 +155,7 @@ def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
|
|||||||
encoder_layers=dimensions["n_audio_layer"],
|
encoder_layers=dimensions["n_audio_layer"],
|
||||||
encoder_attention_heads=dimensions["n_audio_head"],
|
encoder_attention_heads=dimensions["n_audio_head"],
|
||||||
decoder_layers=dimensions["n_text_layer"],
|
decoder_layers=dimensions["n_text_layer"],
|
||||||
decoder_attention_heads=dimensions["n_text_state"],
|
decoder_attention_heads=dimensions["n_text_head"],
|
||||||
max_source_positions=dimensions["n_audio_ctx"],
|
max_source_positions=dimensions["n_audio_ctx"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user