From 06d488061f065f3a8e8f709bce43f3f00eb49052 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Thu, 3 Nov 2022 14:22:40 +0000 Subject: [PATCH] [Whisper Tokenizer] Make more user-friendly (#19921) * [Whisper Tokenizer] Make more user-friendly * use property * make indexing rigorous * small clean-up * tests * skip seq2seq tests * remove multilingual arg * reorder args * collapse to one function Co-authored-by: ArthurZucker * option to override attributes Co-authored-by: ArthurZucker * add to docs * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * make comment more clear Co-authored-by: sgugger * don't add special tokens in get_decoder_prompt_ids * add test for set_prefix_tokens Co-authored-by: ArthurZucker Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: sgugger --- docs/source/en/model_doc/whisper.mdx | 1 + .../models/whisper/processing_whisper.py | 2 +- .../models/whisper/tokenization_whisper.py | 236 +++++++++++++++--- .../whisper/test_tokenization_whisper.py | 97 ++++--- 4 files changed, 277 insertions(+), 59 deletions(-) diff --git a/docs/source/en/model_doc/whisper.mdx b/docs/source/en/model_doc/whisper.mdx index 6e88651d7e..29aab60e3f 100644 --- a/docs/source/en/model_doc/whisper.mdx +++ b/docs/source/en/model_doc/whisper.mdx @@ -39,6 +39,7 @@ The original code can be found [here](https://github.com/openai/whisper). ## WhisperTokenizer [[autodoc]] WhisperTokenizer + - set_prefix_tokens - build_inputs_with_special_tokens - get_special_tokens_mask - create_token_type_ids_from_sequences diff --git a/src/transformers/models/whisper/processing_whisper.py b/src/transformers/models/whisper/processing_whisper.py index c09e14e5ea..7f0e5a0292 100644 --- a/src/transformers/models/whisper/processing_whisper.py +++ b/src/transformers/models/whisper/processing_whisper.py @@ -70,7 +70,7 @@ class WhisperProcessor(ProcessorMixin): forced_decoder_tokens += f"<|{task}|>" forced_decoder_tokens += "<|notimestamps|>" if no_timestamps else "" - ids = self.tokenizer.encode(forced_decoder_tokens) + ids = self.tokenizer.encode(forced_decoder_tokens, add_special_tokens=False) forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(ids)] return forced_decoder_ids diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 696aa4f4e5..22c33a7a3f 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -89,9 +89,130 @@ def get_pairs(word): return pairs +LANGUAGES = { + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "iw": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": "malay", + "cs": "czech", + "ro": "romanian", + "da": "danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": "mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": "somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + "lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese", +} + +# language code lookup by name, with a few language aliases +TO_LANGUAGE_CODE = { + **{language: code for code, language in LANGUAGES.items()}, + "burmese": "my", + "valencian": "ca", + "flemish": "nl", + "haitian": "ht", + "letzeburgesch": "lb", + "pushto": "ps", + "panjabi": "pa", + "moldavian": "ro", + "moldovan": "ro", + "sinhalese": "si", + "castilian": "es", +} + +TASK_IDS = ["translate", "transcribe"] + + class WhisperTokenizer(PreTrainedTokenizer): """ - Construct an Whisper tokenizer. + Construct a Whisper tokenizer. This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to the superclass for more information regarding such methods. @@ -109,16 +230,22 @@ class WhisperTokenizer(PreTrainedTokenizer): unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this token instead. - bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + bos_token (`str`, *optional*, defaults to `"<|startoftranscript|>"`): The beginning of sequence token. eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): The end of sequence token. add_prefix_space (`bool`, *optional*, defaults to `False`): Whether or not to add an initial space to the input. This allows to treat the leading word just as any other word. - add_bos_token (`bool`, *optional*, defaults to `False`): - Whether or not to add an initial <|endoftext|> to the input. This allows to treat the leading word just as - any other word. + language (`str`, *optional*): + The language of the transcription text. The corresponding language id token is appended to the start of the + sequence for multilingual speech recognition and speech translation tasks, e.g. for Spanish the token + `"<|es|>"` is appended to the start of sequence. This should be used for multilingual fine-tuning only. + task (`str`, *optional*): + Task identifier to append at the start of sequence (if any). This should be used for mulitlingual + fine-tuning, with `"transcribe"` for speech recognition and `"translate"` for speech translation. + predict_timestamps (`bool`, *optional*, defaults to `False`): + Whether to omit the `<|notimestamps|>` token at the start of the sequence. """ vocab_files_names = VOCAB_FILES_NAMES @@ -133,11 +260,13 @@ class WhisperTokenizer(PreTrainedTokenizer): normalizer_file=None, errors="replace", unk_token="<|endoftext|>", - bos_token="<|endoftext|>", + bos_token="<|startoftranscript|>", eos_token="<|endoftext|>", pad_token=None, add_prefix_space=False, - add_bos_token=False, + language=None, + task=None, + predict_timestamps=False, **kwargs ): @@ -152,10 +281,8 @@ class WhisperTokenizer(PreTrainedTokenizer): eos_token=eos_token, pad_token=pad_token, add_prefix_space=add_prefix_space, - add_bos_token=add_bos_token, **kwargs, ) - self.add_bos_token = add_bos_token with open(vocab_file, encoding="utf-8") as vocab_handle: self.encoder = json.load(vocab_handle) @@ -179,6 +306,10 @@ class WhisperTokenizer(PreTrainedTokenizer): # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + self.language = language + self.task = task + self.predict_timestamps = predict_timestamps + def get_vocab(self): vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} vocab.update(self.added_tokens_encoder) @@ -231,27 +362,76 @@ class WhisperTokenizer(PreTrainedTokenizer): self.cache[token] = word return word - # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.build_inputs_with_special_tokens with GPT2 -> Whisper - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - if self.add_bos_token: - bos_token_ids = [self.bos_token_id] - else: - bos_token_ids = [] + def set_prefix_tokens(self, language: str = None, task: str = None, predict_timestamps: bool = None): + """ + Override the prefix tokens appended to the start of the label sequence. This method can be used standalone to + update the prefix tokens as required when fine-tuning. Example: - output = bos_token_ids + token_ids_0 + ```python + >>> # instantiate the tokenizer and set the prefix token to Spanish + >>> tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish") + >>> # now switch the prefix token from Spanish to French + >>> tokenizer.set_prefix_tokens(language="french") + ``` + Args: + language (`str`, *optional*, defaults to `None`): + The language of the transcription text. + task (`str`, *optional*, defaults to `None`): + Task identifier to append at the start of sequence (if any). + predict_timestamps (`bool`, *optional*, defaults to `None`): + Whether to omit the `<|notimestamps|>` token at the start of the sequence. + """ + self.language = language if language is not None else self.language + self.task = task if task is not None else self.task + self.predict_timestamps = predict_timestamps if predict_timestamps is not None else self.predict_timestamps + + @property + def prefix_tokens(self) -> List[int]: + all_special_ids = self.all_special_ids + bos_token_id = all_special_ids[-106] + translate_token_id = all_special_ids[-6] + transcribe_token_id = all_special_ids[-5] + notimestamps_token_id = all_special_ids[-1] + langs = tuple(LANGUAGES.keys()) + + if self.language is not None: + self.language = self.language.lower() + if self.language in TO_LANGUAGE_CODE: + language_id = TO_LANGUAGE_CODE[self.language] + else: + raise ValueError( + f"Unsupported language: {self.language}. Language should be in: {TO_LANGUAGE_CODE.keys()}" + ) + + if self.task is not None: + if self.task not in TASK_IDS: + raise ValueError(f"Unsupported task: {self.task}. Task should be in: {TASK_IDS}") + + bos_sequence = [bos_token_id] + if self.language is not None: + bos_sequence.append(bos_token_id + 1 + langs.index(language_id)) + if self.task is not None: + bos_sequence.append(transcribe_token_id if self.task == "transcribe" else translate_token_id) + if not self.predict_timestamps: + bos_sequence.append(notimestamps_token_id) + return bos_sequence + + # Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]: + """Build model inputs from a sequence by appending eos_token_id.""" if token_ids_1 is None: - return output + return self.prefix_tokens + token_ids_0 + [self.eos_token_id] + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + [self.eos_token_id] - return output + bos_token_ids + token_ids_1 - - # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_special_tokens_mask with GPT2 -> Whisper + # Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.get_special_tokens_mask def get_special_tokens_mask( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False ) -> List[int]: """ - Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding - special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. Args: token_ids_0 (`List[int]`): @@ -264,19 +444,17 @@ class WhisperTokenizer(PreTrainedTokenizer): Returns: `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. """ + if already_has_special_tokens: return super().get_special_tokens_mask( token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True ) - if not self.add_bos_token: - return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False - ) - + prefix_ones = [1] * len(self.prefix_tokens) + suffix_ones = [1] if token_ids_1 is None: - return [1] + ([0] * len(token_ids_0)) - return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones + return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize with GPT2 -> Whisper def _tokenize(self, text): diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py index 4dc66a4991..d01c41c0ae 100644 --- a/tests/models/whisper/test_tokenization_whisper.py +++ b/tests/models/whisper/test_tokenization_whisper.py @@ -20,14 +20,20 @@ from transformers.testing_utils import slow from ...test_tokenization_common import TokenizerTesterMixin -EN_CODE = 50258 -ES_CODE = 50256 +ES_CODE = 50262 +EN_CODE = 50259 +END_OF_TRANSCRIPT = 50257 +START_OF_TRANSCRIPT = 50258 +TRANSLATE = 50358 +TRANSCRIBE = 50359 +NOTIMESTAMPS = 50363 class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = WhisperTokenizer test_rust_tokenizer = False test_sentencepiece = False + test_seq2seq = False def setUp(self): super().setUp() @@ -101,13 +107,6 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase): checkpoint_name = "openai/whisper-small.en" - transcript = ( - "'<|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|> Nor is Mr. Quilters manner less interesting" - " than his matter.<|endoftext|>'" - ) - clean_transcript = " Nor is Mr. Quilters manner less interesting than his matter." - french_text = "Bonjour! Il me semble que Mrs Quilters n'était pas présente" - @classmethod def setUpClass(cls): cls.tokenizer: WhisperTokenizer = WhisperTokenizer.from_pretrained(cls.checkpoint_name) @@ -115,15 +114,15 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase): def test_tokenizer_equivalence(self): text = "다람쥐 헌 쳇바퀴에 타고파" - multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="ko") - gpt2_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en") + multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="korean") + monolingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en") - gpt2_tokens = gpt2_tokenizer.encode(text) - multilingual_tokens = multilingual_tokenizer.encode(text) + monolingual_tokens = monolingual_tokenizer.encode(text, add_special_tokens=False) + multilingual_tokens = multilingual_tokenizer.encode(text, add_special_tokens=False) - assert gpt2_tokenizer.decode(gpt2_tokens) == text + assert monolingual_tokenizer.decode(monolingual_tokens) == text assert multilingual_tokenizer.decode(multilingual_tokens) == text - assert len(gpt2_tokens) > len(multilingual_tokens) + assert len(monolingual_tokens) > len(multilingual_tokens) # fmt: off EXPECTED_ENG = [ @@ -138,35 +137,42 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase): ] # fmt: on - self.assertListEqual(gpt2_tokens, EXPECTED_ENG) + self.assertListEqual(monolingual_tokens, EXPECTED_ENG) self.assertListEqual(multilingual_tokens, EXPECTED_MULTI) def test_tokenizer_special(self): - multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en") - text = "<|startoftranscript|>Hey! How are you feeling? J'ai l'impression que 郷さん est prêt<|endoftext|>" + multilingual_tokenizer = WhisperTokenizer.from_pretrained( + "openai/whisper-tiny", language="english", task="transcribe" + ) + text = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt" multilingual_tokens = multilingual_tokenizer.encode(text) # fmt: off + # format: <|startoftranscript|> <|lang-id|> <|task|> <|notimestamps|> ... transcription ids ... <|endoftext|> EXPECTED_MULTI = [ - 50257, 10814, 0, 1374, 389, 345, 4203, 30, 449, 6, - 1872, 300, 6, 11011, 2234, 8358, 16268, 225, 115, 43357, - 22174, 1556, 778, 25792, 83, 50256 + START_OF_TRANSCRIPT, EN_CODE, TRANSCRIBE, NOTIMESTAMPS, 7057, 0, 1012, 366, 291, + 2633, 30, 508, 6, 1301, 287, 6, 36107, 631, 220, 11178, + 115, 15567, 871, 44393, END_OF_TRANSCRIPT ] + EXPECTED_SPECIAL_TEXT = ( + "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>Hey! How are you feeling? " + "J'ai l'impression que 郷さん est prêt<|endoftext|>" + ) # fmt: on self.assertListEqual(multilingual_tokens, EXPECTED_MULTI) - self.assertEqual(text, multilingual_tokenizer.decode(multilingual_tokens)) + special_transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=False) + self.assertEqual(special_transcript, EXPECTED_SPECIAL_TEXT) transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=True) - - EXPECTED_JAP = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt" - self.assertEqual(transcript, EXPECTED_JAP) + self.assertEqual(transcript, text) def test_vocab_size(self): self.assertEqual(self.tokenizer.vocab_size, 50257) + # Copied from transformers.tests.speech_to_test.test_tokenization_speech_to_text.py def test_tokenizer_decode_ignores_language_codes(self): self.assertIn(ES_CODE, self.tokenizer.all_special_ids) generated_ids = [ES_CODE, 4, 1601, 47, 7647, 2] @@ -176,15 +182,48 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase): self.assertNotIn(self.tokenizer.eos_token, result) def test_batch_encoding(self): - multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en") - batch = ["<|en|><|notimestamps|>", "<|en|><|notimestamps|>I am sure that"] + multilingual_tokenizer = WhisperTokenizer.from_pretrained( + "openai/whisper-tiny", language="spanish", task="translate" + ) + batch = ["El gato ", "El gato se sentó"] batch_output = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids # fmt: off EXPECTED_MULTI = [ - [50258, 50362, 50256, 50256, 50256, 50256], - [50258, 50362, 40, 716, 1654, 326] + [START_OF_TRANSCRIPT, ES_CODE, TRANSLATE, NOTIMESTAMPS, 17356, 290, 2513, 220, + END_OF_TRANSCRIPT, END_OF_TRANSCRIPT, END_OF_TRANSCRIPT], + [START_OF_TRANSCRIPT, ES_CODE, TRANSLATE, NOTIMESTAMPS, 17356, 290, 2513, 369, + 2279, 812, END_OF_TRANSCRIPT] ] # fmt: on self.assertListEqual(batch_output, EXPECTED_MULTI) + + def test_set_prefix_tokens(self): + multilingual_tokenizer = WhisperTokenizer.from_pretrained( + "openai/whisper-tiny", language="spanish", task="translate" + ) + + # change the language prefix token from Spanish to English + multilingual_tokenizer.set_prefix_tokens(language="english") + + batch = ["the cat", "the cat sat"] + batch_output = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids + + # fmt: off + EXPECTED_MULTI = [ + [START_OF_TRANSCRIPT, EN_CODE, TRANSLATE, NOTIMESTAMPS, 3322, 3857, + END_OF_TRANSCRIPT, END_OF_TRANSCRIPT], + [START_OF_TRANSCRIPT, EN_CODE, TRANSLATE, NOTIMESTAMPS, 3322, 3857, + 3227, END_OF_TRANSCRIPT] + ] + # fmt: on + + self.assertListEqual(batch_output, EXPECTED_MULTI) + + def test_batch_encoding_decoding(self): + multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish") + batch = ["hola güey", "que onda"] + batch_encoding = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids + transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True) + self.assertListEqual(batch, transcription)