[Whisper Tokenizer] Make more user-friendly (#19921)
* [Whisper Tokenizer] Make more user-friendly * use property * make indexing rigorous * small clean-up * tests * skip seq2seq tests * remove multilingual arg * reorder args * collapse to one function Co-authored-by: ArthurZucker <arthur@huggingface.co> * option to override attributes Co-authored-by: ArthurZucker <arthur@huggingface.co> * add to docs * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * make comment more clear Co-authored-by: sgugger <sylvain@huggingface.co> * don't add special tokens in get_decoder_prompt_ids * add test for set_prefix_tokens Co-authored-by: ArthurZucker <arthur@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: sgugger <sylvain@huggingface.co>
This commit is contained in:
@@ -39,6 +39,7 @@ The original code can be found [here](https://github.com/openai/whisper).
|
|||||||
## WhisperTokenizer
|
## WhisperTokenizer
|
||||||
|
|
||||||
[[autodoc]] WhisperTokenizer
|
[[autodoc]] WhisperTokenizer
|
||||||
|
- set_prefix_tokens
|
||||||
- build_inputs_with_special_tokens
|
- build_inputs_with_special_tokens
|
||||||
- get_special_tokens_mask
|
- get_special_tokens_mask
|
||||||
- create_token_type_ids_from_sequences
|
- create_token_type_ids_from_sequences
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class WhisperProcessor(ProcessorMixin):
|
|||||||
forced_decoder_tokens += f"<|{task}|>"
|
forced_decoder_tokens += f"<|{task}|>"
|
||||||
|
|
||||||
forced_decoder_tokens += "<|notimestamps|>" if no_timestamps else ""
|
forced_decoder_tokens += "<|notimestamps|>" if no_timestamps else ""
|
||||||
ids = self.tokenizer.encode(forced_decoder_tokens)
|
ids = self.tokenizer.encode(forced_decoder_tokens, add_special_tokens=False)
|
||||||
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(ids)]
|
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(ids)]
|
||||||
return forced_decoder_ids
|
return forced_decoder_ids
|
||||||
|
|
||||||
|
|||||||
@@ -89,9 +89,130 @@ def get_pairs(word):
|
|||||||
return pairs
|
return pairs
|
||||||
|
|
||||||
|
|
||||||
|
LANGUAGES = {
|
||||||
|
"en": "english",
|
||||||
|
"zh": "chinese",
|
||||||
|
"de": "german",
|
||||||
|
"es": "spanish",
|
||||||
|
"ru": "russian",
|
||||||
|
"ko": "korean",
|
||||||
|
"fr": "french",
|
||||||
|
"ja": "japanese",
|
||||||
|
"pt": "portuguese",
|
||||||
|
"tr": "turkish",
|
||||||
|
"pl": "polish",
|
||||||
|
"ca": "catalan",
|
||||||
|
"nl": "dutch",
|
||||||
|
"ar": "arabic",
|
||||||
|
"sv": "swedish",
|
||||||
|
"it": "italian",
|
||||||
|
"id": "indonesian",
|
||||||
|
"hi": "hindi",
|
||||||
|
"fi": "finnish",
|
||||||
|
"vi": "vietnamese",
|
||||||
|
"iw": "hebrew",
|
||||||
|
"uk": "ukrainian",
|
||||||
|
"el": "greek",
|
||||||
|
"ms": "malay",
|
||||||
|
"cs": "czech",
|
||||||
|
"ro": "romanian",
|
||||||
|
"da": "danish",
|
||||||
|
"hu": "hungarian",
|
||||||
|
"ta": "tamil",
|
||||||
|
"no": "norwegian",
|
||||||
|
"th": "thai",
|
||||||
|
"ur": "urdu",
|
||||||
|
"hr": "croatian",
|
||||||
|
"bg": "bulgarian",
|
||||||
|
"lt": "lithuanian",
|
||||||
|
"la": "latin",
|
||||||
|
"mi": "maori",
|
||||||
|
"ml": "malayalam",
|
||||||
|
"cy": "welsh",
|
||||||
|
"sk": "slovak",
|
||||||
|
"te": "telugu",
|
||||||
|
"fa": "persian",
|
||||||
|
"lv": "latvian",
|
||||||
|
"bn": "bengali",
|
||||||
|
"sr": "serbian",
|
||||||
|
"az": "azerbaijani",
|
||||||
|
"sl": "slovenian",
|
||||||
|
"kn": "kannada",
|
||||||
|
"et": "estonian",
|
||||||
|
"mk": "macedonian",
|
||||||
|
"br": "breton",
|
||||||
|
"eu": "basque",
|
||||||
|
"is": "icelandic",
|
||||||
|
"hy": "armenian",
|
||||||
|
"ne": "nepali",
|
||||||
|
"mn": "mongolian",
|
||||||
|
"bs": "bosnian",
|
||||||
|
"kk": "kazakh",
|
||||||
|
"sq": "albanian",
|
||||||
|
"sw": "swahili",
|
||||||
|
"gl": "galician",
|
||||||
|
"mr": "marathi",
|
||||||
|
"pa": "punjabi",
|
||||||
|
"si": "sinhala",
|
||||||
|
"km": "khmer",
|
||||||
|
"sn": "shona",
|
||||||
|
"yo": "yoruba",
|
||||||
|
"so": "somali",
|
||||||
|
"af": "afrikaans",
|
||||||
|
"oc": "occitan",
|
||||||
|
"ka": "georgian",
|
||||||
|
"be": "belarusian",
|
||||||
|
"tg": "tajik",
|
||||||
|
"sd": "sindhi",
|
||||||
|
"gu": "gujarati",
|
||||||
|
"am": "amharic",
|
||||||
|
"yi": "yiddish",
|
||||||
|
"lo": "lao",
|
||||||
|
"uz": "uzbek",
|
||||||
|
"fo": "faroese",
|
||||||
|
"ht": "haitian creole",
|
||||||
|
"ps": "pashto",
|
||||||
|
"tk": "turkmen",
|
||||||
|
"nn": "nynorsk",
|
||||||
|
"mt": "maltese",
|
||||||
|
"sa": "sanskrit",
|
||||||
|
"lb": "luxembourgish",
|
||||||
|
"my": "myanmar",
|
||||||
|
"bo": "tibetan",
|
||||||
|
"tl": "tagalog",
|
||||||
|
"mg": "malagasy",
|
||||||
|
"as": "assamese",
|
||||||
|
"tt": "tatar",
|
||||||
|
"haw": "hawaiian",
|
||||||
|
"ln": "lingala",
|
||||||
|
"ha": "hausa",
|
||||||
|
"ba": "bashkir",
|
||||||
|
"jw": "javanese",
|
||||||
|
"su": "sundanese",
|
||||||
|
}
|
||||||
|
|
||||||
|
# language code lookup by name, with a few language aliases
|
||||||
|
TO_LANGUAGE_CODE = {
|
||||||
|
**{language: code for code, language in LANGUAGES.items()},
|
||||||
|
"burmese": "my",
|
||||||
|
"valencian": "ca",
|
||||||
|
"flemish": "nl",
|
||||||
|
"haitian": "ht",
|
||||||
|
"letzeburgesch": "lb",
|
||||||
|
"pushto": "ps",
|
||||||
|
"panjabi": "pa",
|
||||||
|
"moldavian": "ro",
|
||||||
|
"moldovan": "ro",
|
||||||
|
"sinhalese": "si",
|
||||||
|
"castilian": "es",
|
||||||
|
}
|
||||||
|
|
||||||
|
TASK_IDS = ["translate", "transcribe"]
|
||||||
|
|
||||||
|
|
||||||
class WhisperTokenizer(PreTrainedTokenizer):
|
class WhisperTokenizer(PreTrainedTokenizer):
|
||||||
"""
|
"""
|
||||||
Construct an Whisper tokenizer.
|
Construct a Whisper tokenizer.
|
||||||
|
|
||||||
This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to
|
This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to
|
||||||
the superclass for more information regarding such methods.
|
the superclass for more information regarding such methods.
|
||||||
@@ -109,16 +230,22 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||||
token instead.
|
token instead.
|
||||||
bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
bos_token (`str`, *optional*, defaults to `"<|startoftranscript|>"`):
|
||||||
The beginning of sequence token.
|
The beginning of sequence token.
|
||||||
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
|
||||||
The end of sequence token.
|
The end of sequence token.
|
||||||
add_prefix_space (`bool`, *optional*, defaults to `False`):
|
add_prefix_space (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
|
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
|
||||||
other word.
|
other word.
|
||||||
add_bos_token (`bool`, *optional*, defaults to `False`):
|
language (`str`, *optional*):
|
||||||
Whether or not to add an initial <|endoftext|> to the input. This allows to treat the leading word just as
|
The language of the transcription text. The corresponding language id token is appended to the start of the
|
||||||
any other word.
|
sequence for multilingual speech recognition and speech translation tasks, e.g. for Spanish the token
|
||||||
|
`"<|es|>"` is appended to the start of sequence. This should be used for multilingual fine-tuning only.
|
||||||
|
task (`str`, *optional*):
|
||||||
|
Task identifier to append at the start of sequence (if any). This should be used for multilingual
|
||||||
|
fine-tuning, with `"transcribe"` for speech recognition and `"translate"` for speech translation.
|
||||||
|
predict_timestamps (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to omit the `<|notimestamps|>` token at the start of the sequence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
vocab_files_names = VOCAB_FILES_NAMES
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
@@ -133,11 +260,13 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
normalizer_file=None,
|
normalizer_file=None,
|
||||||
errors="replace",
|
errors="replace",
|
||||||
unk_token="<|endoftext|>",
|
unk_token="<|endoftext|>",
|
||||||
bos_token="<|endoftext|>",
|
bos_token="<|startoftranscript|>",
|
||||||
eos_token="<|endoftext|>",
|
eos_token="<|endoftext|>",
|
||||||
pad_token=None,
|
pad_token=None,
|
||||||
add_prefix_space=False,
|
add_prefix_space=False,
|
||||||
add_bos_token=False,
|
language=None,
|
||||||
|
task=None,
|
||||||
|
predict_timestamps=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
|
||||||
@@ -152,10 +281,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
eos_token=eos_token,
|
eos_token=eos_token,
|
||||||
pad_token=pad_token,
|
pad_token=pad_token,
|
||||||
add_prefix_space=add_prefix_space,
|
add_prefix_space=add_prefix_space,
|
||||||
add_bos_token=add_bos_token,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.add_bos_token = add_bos_token
|
|
||||||
|
|
||||||
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||||
self.encoder = json.load(vocab_handle)
|
self.encoder = json.load(vocab_handle)
|
||||||
@@ -179,6 +306,10 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
||||||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
||||||
|
|
||||||
|
self.language = language
|
||||||
|
self.task = task
|
||||||
|
self.predict_timestamps = predict_timestamps
|
||||||
|
|
||||||
def get_vocab(self):
|
def get_vocab(self):
|
||||||
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
||||||
vocab.update(self.added_tokens_encoder)
|
vocab.update(self.added_tokens_encoder)
|
||||||
@@ -231,27 +362,76 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
self.cache[token] = word
|
self.cache[token] = word
|
||||||
return word
|
return word
|
||||||
|
|
||||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.build_inputs_with_special_tokens with GPT2 -> Whisper
|
def set_prefix_tokens(self, language: str = None, task: str = None, predict_timestamps: bool = None):
|
||||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
"""
|
||||||
if self.add_bos_token:
|
Override the prefix tokens appended to the start of the label sequence. This method can be used standalone to
|
||||||
bos_token_ids = [self.bos_token_id]
|
update the prefix tokens as required when fine-tuning. Example:
|
||||||
else:
|
|
||||||
bos_token_ids = []
|
|
||||||
|
|
||||||
output = bos_token_ids + token_ids_0
|
```python
|
||||||
|
>>> # instantiate the tokenizer and set the prefix token to Spanish
|
||||||
|
>>> tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish")
|
||||||
|
>>> # now switch the prefix token from Spanish to French
|
||||||
|
>>> tokenizer.set_prefix_tokens(language="french")
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
language (`str`, *optional*, defaults to `None`):
|
||||||
|
The language of the transcription text.
|
||||||
|
task (`str`, *optional*, defaults to `None`):
|
||||||
|
Task identifier to append at the start of sequence (if any).
|
||||||
|
predict_timestamps (`bool`, *optional*, defaults to `None`):
|
||||||
|
Whether to omit the `<|notimestamps|>` token at the start of the sequence.
|
||||||
|
"""
|
||||||
|
self.language = language if language is not None else self.language
|
||||||
|
self.task = task if task is not None else self.task
|
||||||
|
self.predict_timestamps = predict_timestamps if predict_timestamps is not None else self.predict_timestamps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prefix_tokens(self) -> List[int]:
|
||||||
|
all_special_ids = self.all_special_ids
|
||||||
|
bos_token_id = all_special_ids[-106]
|
||||||
|
translate_token_id = all_special_ids[-6]
|
||||||
|
transcribe_token_id = all_special_ids[-5]
|
||||||
|
notimestamps_token_id = all_special_ids[-1]
|
||||||
|
langs = tuple(LANGUAGES.keys())
|
||||||
|
|
||||||
|
if self.language is not None:
|
||||||
|
self.language = self.language.lower()
|
||||||
|
if self.language in TO_LANGUAGE_CODE:
|
||||||
|
language_id = TO_LANGUAGE_CODE[self.language]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported language: {self.language}. Language should be in: {TO_LANGUAGE_CODE.keys()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.task is not None:
|
||||||
|
if self.task not in TASK_IDS:
|
||||||
|
raise ValueError(f"Unsupported task: {self.task}. Task should be in: {TASK_IDS}")
|
||||||
|
|
||||||
|
bos_sequence = [bos_token_id]
|
||||||
|
if self.language is not None:
|
||||||
|
bos_sequence.append(bos_token_id + 1 + langs.index(language_id))
|
||||||
|
if self.task is not None:
|
||||||
|
bos_sequence.append(transcribe_token_id if self.task == "transcribe" else translate_token_id)
|
||||||
|
if not self.predict_timestamps:
|
||||||
|
bos_sequence.append(notimestamps_token_id)
|
||||||
|
return bos_sequence
|
||||||
|
|
||||||
|
# Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.build_inputs_with_special_tokens
|
||||||
|
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
|
||||||
|
"""Build model inputs from a sequence by prepending the prefix tokens and appending eos_token_id."""
|
||||||
if token_ids_1 is None:
|
if token_ids_1 is None:
|
||||||
return output
|
return self.prefix_tokens + token_ids_0 + [self.eos_token_id]
|
||||||
|
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||||
|
return self.prefix_tokens + token_ids_0 + token_ids_1 + [self.eos_token_id]
|
||||||
|
|
||||||
return output + bos_token_ids + token_ids_1
|
# Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.get_special_tokens_mask
|
||||||
|
|
||||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_special_tokens_mask with GPT2 -> Whisper
|
|
||||||
def get_special_tokens_mask(
|
def get_special_tokens_mask(
|
||||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||||
special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
|
special tokens using the tokenizer `prepare_for_model` method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token_ids_0 (`List[int]`):
|
token_ids_0 (`List[int]`):
|
||||||
@@ -264,19 +444,17 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
Returns:
|
Returns:
|
||||||
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if already_has_special_tokens:
|
if already_has_special_tokens:
|
||||||
return super().get_special_tokens_mask(
|
return super().get_special_tokens_mask(
|
||||||
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.add_bos_token:
|
prefix_ones = [1] * len(self.prefix_tokens)
|
||||||
return super().get_special_tokens_mask(
|
suffix_ones = [1]
|
||||||
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False
|
|
||||||
)
|
|
||||||
|
|
||||||
if token_ids_1 is None:
|
if token_ids_1 is None:
|
||||||
return [1] + ([0] * len(token_ids_0))
|
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
|
||||||
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
|
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
|
||||||
|
|
||||||
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize with GPT2 -> Whisper
|
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize with GPT2 -> Whisper
|
||||||
def _tokenize(self, text):
|
def _tokenize(self, text):
|
||||||
|
|||||||
@@ -20,14 +20,20 @@ from transformers.testing_utils import slow
|
|||||||
from ...test_tokenization_common import TokenizerTesterMixin
|
from ...test_tokenization_common import TokenizerTesterMixin
|
||||||
|
|
||||||
|
|
||||||
EN_CODE = 50258
|
ES_CODE = 50262
|
||||||
ES_CODE = 50256
|
EN_CODE = 50259
|
||||||
|
END_OF_TRANSCRIPT = 50257
|
||||||
|
START_OF_TRANSCRIPT = 50258
|
||||||
|
TRANSLATE = 50358
|
||||||
|
TRANSCRIBE = 50359
|
||||||
|
NOTIMESTAMPS = 50363
|
||||||
|
|
||||||
|
|
||||||
class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||||
tokenizer_class = WhisperTokenizer
|
tokenizer_class = WhisperTokenizer
|
||||||
test_rust_tokenizer = False
|
test_rust_tokenizer = False
|
||||||
test_sentencepiece = False
|
test_sentencepiece = False
|
||||||
|
test_seq2seq = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
@@ -101,13 +107,6 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||||
checkpoint_name = "openai/whisper-small.en"
|
checkpoint_name = "openai/whisper-small.en"
|
||||||
|
|
||||||
transcript = (
|
|
||||||
"'<|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|> Nor is Mr. Quilters manner less interesting"
|
|
||||||
" than his matter.<|endoftext|>'"
|
|
||||||
)
|
|
||||||
clean_transcript = " Nor is Mr. Quilters manner less interesting than his matter."
|
|
||||||
french_text = "Bonjour! Il me semble que Mrs Quilters n'était pas présente"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.tokenizer: WhisperTokenizer = WhisperTokenizer.from_pretrained(cls.checkpoint_name)
|
cls.tokenizer: WhisperTokenizer = WhisperTokenizer.from_pretrained(cls.checkpoint_name)
|
||||||
@@ -115,15 +114,15 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_tokenizer_equivalence(self):
|
def test_tokenizer_equivalence(self):
|
||||||
text = "다람쥐 헌 쳇바퀴에 타고파"
|
text = "다람쥐 헌 쳇바퀴에 타고파"
|
||||||
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="ko")
|
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="korean")
|
||||||
gpt2_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
|
monolingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
|
||||||
|
|
||||||
gpt2_tokens = gpt2_tokenizer.encode(text)
|
monolingual_tokens = monolingual_tokenizer.encode(text, add_special_tokens=False)
|
||||||
multilingual_tokens = multilingual_tokenizer.encode(text)
|
multilingual_tokens = multilingual_tokenizer.encode(text, add_special_tokens=False)
|
||||||
|
|
||||||
assert gpt2_tokenizer.decode(gpt2_tokens) == text
|
assert monolingual_tokenizer.decode(monolingual_tokens) == text
|
||||||
assert multilingual_tokenizer.decode(multilingual_tokens) == text
|
assert multilingual_tokenizer.decode(multilingual_tokens) == text
|
||||||
assert len(gpt2_tokens) > len(multilingual_tokens)
|
assert len(monolingual_tokens) > len(multilingual_tokens)
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_ENG = [
|
EXPECTED_ENG = [
|
||||||
@@ -138,35 +137,42 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
self.assertListEqual(gpt2_tokens, EXPECTED_ENG)
|
self.assertListEqual(monolingual_tokens, EXPECTED_ENG)
|
||||||
self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)
|
self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)
|
||||||
|
|
||||||
def test_tokenizer_special(self):
|
def test_tokenizer_special(self):
|
||||||
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
|
multilingual_tokenizer = WhisperTokenizer.from_pretrained(
|
||||||
text = "<|startoftranscript|>Hey! How are you feeling? J'ai l'impression que 郷さん est prêt<|endoftext|>"
|
"openai/whisper-tiny", language="english", task="transcribe"
|
||||||
|
)
|
||||||
|
text = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
|
||||||
|
|
||||||
multilingual_tokens = multilingual_tokenizer.encode(text)
|
multilingual_tokens = multilingual_tokenizer.encode(text)
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
# format: <|startoftranscript|> <|lang-id|> <|task|> <|notimestamps|> ... transcription ids ... <|endoftext|>
|
||||||
EXPECTED_MULTI = [
|
EXPECTED_MULTI = [
|
||||||
50257, 10814, 0, 1374, 389, 345, 4203, 30, 449, 6,
|
START_OF_TRANSCRIPT, EN_CODE, TRANSCRIBE, NOTIMESTAMPS, 7057, 0, 1012, 366, 291,
|
||||||
1872, 300, 6, 11011, 2234, 8358, 16268, 225, 115, 43357,
|
2633, 30, 508, 6, 1301, 287, 6, 36107, 631, 220, 11178,
|
||||||
22174, 1556, 778, 25792, 83, 50256
|
115, 15567, 871, 44393, END_OF_TRANSCRIPT
|
||||||
]
|
]
|
||||||
|
EXPECTED_SPECIAL_TEXT = (
|
||||||
|
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|>Hey! How are you feeling? "
|
||||||
|
"J'ai l'impression que 郷さん est prêt<|endoftext|>"
|
||||||
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)
|
self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)
|
||||||
|
|
||||||
self.assertEqual(text, multilingual_tokenizer.decode(multilingual_tokens))
|
special_transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=False)
|
||||||
|
self.assertEqual(special_transcript, EXPECTED_SPECIAL_TEXT)
|
||||||
|
|
||||||
transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=True)
|
transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=True)
|
||||||
|
self.assertEqual(transcript, text)
|
||||||
EXPECTED_JAP = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
|
|
||||||
self.assertEqual(transcript, EXPECTED_JAP)
|
|
||||||
|
|
||||||
def test_vocab_size(self):
|
def test_vocab_size(self):
|
||||||
self.assertEqual(self.tokenizer.vocab_size, 50257)
|
self.assertEqual(self.tokenizer.vocab_size, 50257)
|
||||||
|
|
||||||
|
# Copied from transformers.tests.speech_to_text.test_tokenization_speech_to_text.py
|
||||||
def test_tokenizer_decode_ignores_language_codes(self):
|
def test_tokenizer_decode_ignores_language_codes(self):
|
||||||
self.assertIn(ES_CODE, self.tokenizer.all_special_ids)
|
self.assertIn(ES_CODE, self.tokenizer.all_special_ids)
|
||||||
generated_ids = [ES_CODE, 4, 1601, 47, 7647, 2]
|
generated_ids = [ES_CODE, 4, 1601, 47, 7647, 2]
|
||||||
@@ -176,15 +182,48 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
|||||||
self.assertNotIn(self.tokenizer.eos_token, result)
|
self.assertNotIn(self.tokenizer.eos_token, result)
|
||||||
|
|
||||||
def test_batch_encoding(self):
|
def test_batch_encoding(self):
|
||||||
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
|
multilingual_tokenizer = WhisperTokenizer.from_pretrained(
|
||||||
batch = ["<|en|><|notimestamps|>", "<|en|><|notimestamps|>I am sure that"]
|
"openai/whisper-tiny", language="spanish", task="translate"
|
||||||
|
)
|
||||||
|
batch = ["El gato ", "El gato se sentó"]
|
||||||
batch_output = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
|
batch_output = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_MULTI = [
|
EXPECTED_MULTI = [
|
||||||
[50258, 50362, 50256, 50256, 50256, 50256],
|
[START_OF_TRANSCRIPT, ES_CODE, TRANSLATE, NOTIMESTAMPS, 17356, 290, 2513, 220,
|
||||||
[50258, 50362, 40, 716, 1654, 326]
|
END_OF_TRANSCRIPT, END_OF_TRANSCRIPT, END_OF_TRANSCRIPT],
|
||||||
|
[START_OF_TRANSCRIPT, ES_CODE, TRANSLATE, NOTIMESTAMPS, 17356, 290, 2513, 369,
|
||||||
|
2279, 812, END_OF_TRANSCRIPT]
|
||||||
]
|
]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
self.assertListEqual(batch_output, EXPECTED_MULTI)
|
self.assertListEqual(batch_output, EXPECTED_MULTI)
|
||||||
|
|
||||||
|
def test_set_prefix_tokens(self):
|
||||||
|
multilingual_tokenizer = WhisperTokenizer.from_pretrained(
|
||||||
|
"openai/whisper-tiny", language="spanish", task="translate"
|
||||||
|
)
|
||||||
|
|
||||||
|
# change the language prefix token from Spanish to English
|
||||||
|
multilingual_tokenizer.set_prefix_tokens(language="english")
|
||||||
|
|
||||||
|
batch = ["the cat", "the cat sat"]
|
||||||
|
batch_output = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
EXPECTED_MULTI = [
|
||||||
|
[START_OF_TRANSCRIPT, EN_CODE, TRANSLATE, NOTIMESTAMPS, 3322, 3857,
|
||||||
|
END_OF_TRANSCRIPT, END_OF_TRANSCRIPT],
|
||||||
|
[START_OF_TRANSCRIPT, EN_CODE, TRANSLATE, NOTIMESTAMPS, 3322, 3857,
|
||||||
|
3227, END_OF_TRANSCRIPT]
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
self.assertListEqual(batch_output, EXPECTED_MULTI)
|
||||||
|
|
||||||
|
def test_batch_encoding_decoding(self):
|
||||||
|
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish")
|
||||||
|
batch = ["hola güey", "que onda"]
|
||||||
|
batch_encoding = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
|
||||||
|
transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True)
|
||||||
|
self.assertListEqual(batch, transcription)
|
||||||
|
|||||||
Reference in New Issue
Block a user