[Whisper] Add conversion script for the tokenizer (#27338)
* draft * updates * full conversion taken from `https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee` * push * nits * updates * more nits * Add co author Co-authored-by: Joshua Lochner <admin@xenova.com> * fixup * cleanup * styling * add proper path * update * nits * don't push the exit * clean * update whisper doc * don't error out if tiktoken is not here * make sure we are BC with conversion * nit * Update docs/source/en/model_doc/whisper.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * merge and update * update markdown * Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --------- Co-authored-by: Joshua Lochner <admin@xenova.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -34,8 +34,13 @@ The original code can be found [here](https://github.com/openai/whisper).
|
|||||||
- Inference is currently only implemented for short-form i.e. audio is pre-segmented into <=30s segments. Long-form (including timestamps) will be implemented in a future release.
|
- Inference is currently only implemented for short-form i.e. audio is pre-segmented into <=30s segments. Long-form (including timestamps) will be implemented in a future release.
|
||||||
- One can use [`WhisperProcessor`] to prepare audio for the model, and decode the predicted ID's back into text.
|
- One can use [`WhisperProcessor`] to prepare audio for the model, and decode the predicted ID's back into text.
|
||||||
|
|
||||||
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts).
|
- To convert the tokenizer, we recommend using the following:
|
||||||
The original code can be found [here](https://github.com/openai/whisper).
|
|
||||||
|
```bash
|
||||||
|
python src/transformers/models/whisper/convert_openai_to_hf.py --checkpoint_path "" --pytorch_dump_folder_path "Arthur/whisper-3" --convert_tokenizer True --whisper_version 3 --multilingual True
|
||||||
|
```
|
||||||
|
Here, `--whisper_version 3` sets the number of languages to `100`, to account for `cantonese`, which was added in `whisper-large-v3`.
|
||||||
|
|
||||||
|
|
||||||
## Inference
|
## Inference
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,9 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import hashlib
|
import hashlib
|
||||||
import io
|
import io
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
import urllib
|
import urllib
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@@ -25,7 +27,9 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from transformers import WhisperConfig, WhisperForConditionalGeneration
|
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperTokenizer
|
||||||
|
from transformers.models.whisper.tokenization_whisper import LANGUAGES, bytes_to_unicode
|
||||||
|
from transformers.utils.import_utils import _is_package_available
|
||||||
|
|
||||||
|
|
||||||
_MODELS = {
|
_MODELS = {
|
||||||
@@ -41,6 +45,11 @@ _MODELS = {
|
|||||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_TOKENIZERS = {
|
||||||
|
"multilingual": "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/multilingual.tiktoken",
|
||||||
|
"english": "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/gpt2.tiktoken",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def remove_ignore_keys_(state_dict):
|
def remove_ignore_keys_(state_dict):
|
||||||
ignore_keys = ["layers", "blocks"]
|
ignore_keys = ["layers", "blocks"]
|
||||||
@@ -178,11 +187,119 @@ def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
|
|||||||
model.save_pretrained(pytorch_dump_folder_path)
|
model.save_pretrained(pytorch_dump_folder_path)
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from https://github.com/openai/tiktoken/issues/60#issuecomment-1499977960
|
||||||
|
def _bpe(mergeable_ranks, token: bytes, max_rank=None) -> list[bytes]:
|
||||||
|
parts = [bytes([b]) for b in token]
|
||||||
|
while True:
|
||||||
|
min_idx = None
|
||||||
|
min_rank = None
|
||||||
|
for i, pair in enumerate(zip(parts[:-1], parts[1:])):
|
||||||
|
rank = mergeable_ranks.get(pair[0] + pair[1])
|
||||||
|
if rank is not None and (min_rank is None or rank < min_rank):
|
||||||
|
min_idx = i
|
||||||
|
min_rank = rank
|
||||||
|
if min_rank is None or (max_rank is not None and min_rank >= max_rank):
|
||||||
|
break
|
||||||
|
assert min_idx is not None
|
||||||
|
parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :]
|
||||||
|
return parts
|
||||||
|
|
||||||
|
|
||||||
|
def convert_tiktoken_bpe_to_hf(tiktoken_url: str):
    """
    Convert a tiktoken BPE rank file into a vocabulary and a merge list.

    Args:
        tiktoken_url: Location of the `.tiktoken` file, passed as-is to `tiktoken.load.load_tiktoken_bpe`.

    Returns:
        A `(vocab, merges)` tuple. `vocab` maps each token (bytes rendered through the `bytes_to_unicode` table)
        to its rank. `merges` lists, in the order the ranks are iterated, the `"left right"` pair that produces
        each multi-byte token.

    Raises:
        ImportError: If `tiktoken` is not installed.
    """
    # Imported here rather than at module level: `tiktoken` is an optional dependency, and the only other import
    # of this name lives under the `__main__` guard, which is not executed when this module is imported.
    from tiktoken.load import load_tiktoken_bpe

    bpe_ranks = load_tiktoken_bpe(tiktoken_url)
    byte_encoder = bytes_to_unicode()

    def token_bytes_to_string(b):
        # Iterating over `bytes` yields the integer byte values that `byte_encoder` is keyed by.
        return "".join(byte_encoder[byte] for byte in b)

    merges = []
    vocab = {}
    for token, rank in bpe_ranks.items():
        vocab[token_bytes_to_string(token)] = rank
        if len(token) == 1:
            # Single bytes are base symbols; no merge produces them.
            continue
        # Replaying only the merges ranked below this token's own rank leaves the pair that creates it.
        merged = tuple(_bpe(bpe_ranks, token, max_rank=rank))
        if len(merged) == 2:  # account for empty token
            merges.append(" ".join(map(token_bytes_to_string, merged)))
    return vocab, merges
|
||||||
|
|
||||||
|
|
||||||
|
def convert_tiktoken_to_hf(
    pytorch_dump_folder_path: str, multilingual: bool = True, num_languages: int = 100, time_precision=0.02
) -> WhisperTokenizer:
    """
    Build a `WhisperTokenizer` from the tiktoken rank file and save it to `pytorch_dump_folder_path`.

    Args:
        pytorch_dump_folder_path: Directory handed to `save_pretrained`.
        multilingual: Selects the "multilingual" entry of `_TOKENIZERS` when truthy, otherwise "english".
        num_languages: Number of leading entries of `LANGUAGES` that get a `<|lang|>` token.
        time_precision: Step between two consecutive timestamp tokens.

    Returns:
        The tokenizer that was saved.
    """
    # requires whisper, unless we use the path to the tiktoken file
    tiktoken_tokenizer_path = _TOKENIZERS["multilingual" if multilingual else "english"]
    start_of_transcript = ["<|endoftext|>", "<|startoftranscript|>"]
    control_tokens = [
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nocaptions|>",
        "<|notimestamps|>",
    ]
    # these are special tokens, not normalized
    language_tokens = [f"<|{k}|>" for k in list(LANGUAGES)[:num_languages]]
    # These are not special but normalized
    timestamp_tokens = [f"<|{i * time_precision:.2f}|>" for i in range(1500 + 1)]

    vocab, merges = convert_tiktoken_bpe_to_hf(tiktoken_tokenizer_path)

    # The tokenizer is constructed from files on disk, so write the vocab and merges to a scratch directory first.
    with tempfile.TemporaryDirectory() as tmpdirname:
        vocab_file = os.path.join(tmpdirname, "vocab.json")
        merge_file = os.path.join(tmpdirname, "merges.txt")
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            for bpe_tokens in merges:
                writer.write(bpe_tokens + "\n")

        hf_tokenizer = WhisperTokenizer(vocab_file, merge_file)

    # The order of these two calls fixes the order in which the extra tokens are added.
    hf_tokenizer.add_tokens(start_of_transcript + language_tokens + control_tokens, special_tokens=True)
    hf_tokenizer.add_tokens(timestamp_tokens, special_tokens=False)
    hf_tokenizer.save_pretrained(pytorch_dump_folder_path)
    # Honour the declared return type; the original fell off the end and returned None.
    return hf_tokenizer
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":

    def _str_to_bool(value):
        """Parse a command-line boolean; `type=bool` would turn any non-empty string (even "False") into True."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        # The empty string was already falsy under `type=bool`, so it stays accepted as False.
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    # # Required parameters
    parser.add_argument("--checkpoint_path", type=str, help="Patht to the downloaded checkpoints")
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    # Both boolean options accept an explicit value (`--convert_tokenizer True`) or the bare flag.
    parser.add_argument(
        "--convert_tokenizer",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="Whether or not the tokenizer should be converted along with the model.",
    )
    parser.add_argument(
        "--whisper_version",
        type=int,
        default=2,
        help="Version of the whisper release",
    )
    parser.add_argument(
        "--multilingual",
        type=_str_to_bool,
        nargs="?",
        const=True,
        # The previous default, the string "store_true", was truthy, so the effective default was already True.
        default=True,
        help="Whether or not the model is multilingual or english only",
    )
    args = parser.parse_args()

    if args.convert_tokenizer:
        if _is_package_available("tiktoken"):
            # Binds the module-level name that `convert_tiktoken_bpe_to_hf` looks up.
            from tiktoken.load import load_tiktoken_bpe

            NUM_LANGUAGES_PER_RELEASE = {1: 99, 2: 99, 3: 100}
            if args.whisper_version not in NUM_LANGUAGES_PER_RELEASE:
                parser.error(
                    f"--whisper_version must be one of {sorted(NUM_LANGUAGES_PER_RELEASE)}, got {args.whisper_version}"
                )
            convert_tiktoken_to_hf(
                args.pytorch_dump_folder_path, args.multilingual, NUM_LANGUAGES_PER_RELEASE[args.whisper_version]
            )
        else:
            # A missing `tiktoken` stays non-fatal (the model is still converted), but is no longer silent.
            warnings.warn(
                "`tiktoken` is not installed, use `pip install tiktoken` to convert the tokenizer. "
                "Skipping the tokenizer conversion."
            )

    convert_openai_whisper_to_tfms(args.checkpoint_path, args.pytorch_dump_folder_path)
|
||||||
|
|||||||
@@ -191,6 +191,7 @@ LANGUAGES = {
|
|||||||
"ba": "bashkir",
|
"ba": "bashkir",
|
||||||
"jw": "javanese",
|
"jw": "javanese",
|
||||||
"su": "sundanese",
|
"su": "sundanese",
|
||||||
|
"yue": "cantonese",
|
||||||
}
|
}
|
||||||
|
|
||||||
# language code lookup by name, with a few language aliases
|
# language code lookup by name, with a few language aliases
|
||||||
@@ -207,6 +208,7 @@ TO_LANGUAGE_CODE = {
|
|||||||
"moldovan": "ro",
|
"moldovan": "ro",
|
||||||
"sinhalese": "si",
|
"sinhalese": "si",
|
||||||
"castilian": "es",
|
"castilian": "es",
|
||||||
|
"mandarin": "zh",
|
||||||
}
|
}
|
||||||
|
|
||||||
TASK_IDS = ["translate", "transcribe"]
|
TASK_IDS = ["translate", "transcribe"]
|
||||||
@@ -1206,7 +1208,7 @@ def _combine_tokens_into_words(
|
|||||||
if language is None:
|
if language is None:
|
||||||
language = "english"
|
language = "english"
|
||||||
|
|
||||||
if language in {"chinese", "japanese", "thai", "lao", "myanmar"}:
|
if language in {"chinese", "japanese", "thai", "lao", "myanmar", "cantonese"}:
|
||||||
# These languages don't typically use spaces.
|
# These languages don't typically use spaces.
|
||||||
words, word_tokens, token_indices = _split_tokens_on_unicode(tokenizer, tokens)
|
words, word_tokens, token_indices = _split_tokens_on_unicode(tokenizer, tokens)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user