From 57f44dc4288a3521bd700405ad41e90a4687abc0 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Tue, 3 Oct 2023 17:57:16 +0100 Subject: [PATCH] [Whisper] Allow basic text normalization (#26149) * [Whisper] Allow basic text normalization * up * style copies --- .../models/whisper/tokenization_whisper.py | 37 +++++++++++++++--- .../whisper/tokenization_whisper_fast.py | 39 ++++++++++++++++--- .../whisper/test_tokenization_whisper.py | 34 ++++++++++++++++ 3 files changed, 100 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index a6f13d4f38..2e0aadab00 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -23,7 +23,7 @@ import regex as re from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging -from .english_normalizer import EnglishTextNormalizer +from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer VOCAB_FILES_NAMES = { @@ -510,6 +510,15 @@ class WhisperTokenizer(PreTrainedTokenizer): normalizer = EnglishTextNormalizer(self.english_spelling_normalizer) return normalizer(text) + @staticmethod + def _basic_normalize(text, remove_diacritics=False): + """ + Normalize a given string using the `BasicTextNormalizer` class, which preforms commons transformation on + multilingual text. + """ + normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics) + return normalizer(text) + def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str: """ Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes @@ -617,6 +626,9 @@ class WhisperTokenizer(PreTrainedTokenizer): output_offsets: bool = False, time_precision=0.02, decode_with_timestamps: bool = False, + normalize: bool = False, + basic_normalize: bool = False, + remove_diacritics: bool = False, **kwargs, ) -> str: """ @@ -633,8 +645,6 @@ class WhisperTokenizer(PreTrainedTokenizer): clean_up_tokenization_spaces (`bool`, *optional*): Whether or not to clean up the tokenization spaces. If `None`, will default to `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`). - kwargs (additional keyword arguments, *optional*): - Will be passed to the underlying model specific decode method. output_offsets (`bool`, *optional*, defaults to `False`): Whether or not to output the offsets of the tokens. This should only be set if the model predicted timestamps. @@ -642,6 +652,17 @@ class WhisperTokenizer(PreTrainedTokenizer): The time ratio to convert from token to time. decode_with_timestamps (`bool`, *optional*, defaults to `False`): Whether or not to decode with timestamps included in the raw text. + normalize (`bool`, *optional*, defaults to `False`): + Whether or not to apply the English text normalizer to the decoded text. Only applicable when the + target text is in English. Otherwise, the basic text normalizer should be applied. + basic_normalize (`bool`, *optional*, defaults to `False`): + Whether or not to apply the Basic text normalizer to the decoded text. Applicable to multilingual + target text. + remove_diacritics (`bool`, *optional*, defaults to `False`): + Whether or not to remove diacritics when applying the Basic text normalizer. Removing diacritics may + destroy information in the decoded text, hence it should be used with caution. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. Returns: `str`: The decoded sentence. """ @@ -654,7 +675,9 @@ class WhisperTokenizer(PreTrainedTokenizer): filtered_ids, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces, - decode_with_timestamps=decode_with_timestamps, + normalize=normalize, + basic_normalize=basic_normalize, + remove_diacritics=remove_diacritics, **kwargs, ) if decode_with_timestamps: @@ -676,7 +699,8 @@ class WhisperTokenizer(PreTrainedTokenizer): token_ids: Union[int, List[int]], skip_special_tokens: bool = False, normalize: bool = False, - decode_with_timestamps: bool = False, + basic_normalize: bool = False, + remove_diacritics: bool = False, **kwargs, ) -> str: self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False) @@ -705,6 +729,9 @@ class WhisperTokenizer(PreTrainedTokenizer): if normalize: clean_text = self._normalize(text) return clean_text + elif basic_normalize: + clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics) + return clean_text else: return text diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 71b741be52..64a4343a19 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -25,7 +25,7 @@ from tokenizers import AddedToken, pre_tokenizers, processors from ...tokenization_utils_base import BatchEncoding from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...utils import logging -from .english_normalizer import EnglishTextNormalizer +from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr @@ -331,6 +331,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): output_offsets: bool = False, time_precision=0.02, decode_with_timestamps: bool = False, + normalize: bool = False, + basic_normalize: bool = False, + remove_diacritics: bool = False, **kwargs, ) -> str: """ @@ -347,8 +350,6 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): clean_up_tokenization_spaces (`bool`, *optional*): Whether or not to clean up the tokenization spaces. If `None`, will default to `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`). - kwargs (additional keyword arguments, *optional*): - Will be passed to the underlying model specific decode method. output_offsets (`bool`, *optional*, defaults to `False`): Whether or not to output the offsets of the tokens. This should only be set if the model predicted timestamps. @@ -356,6 +357,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): The time ratio to convert from token to time. decode_with_timestamps (`bool`, *optional*, defaults to `False`): Whether or not to decode with timestamps included in the raw text. + normalize (`bool`, *optional*, defaults to `False`): + Whether or not to apply the English text normalizer to the decoded text. Only applicable when the + target text is in English. Otherwise, the basic text normalizer should be applied. + basic_normalize (`bool`, *optional*, defaults to `False`): + Whether or not to apply the Basic text normalizer to the decoded text. Applicable to multilingual + target text. + remove_diacritics (`bool`, *optional*, defaults to `False`): + Whether or not to remove diacritics when applying the Basic text normalizer. Removing diacritics may + destroy information in the decoded text, hence it should be used with caution. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. Returns: `str`: The decoded sentence. """ @@ -368,7 +380,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): filtered_ids, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces, - decode_with_timestamps=decode_with_timestamps, + normalize=normalize, + basic_normalize=basic_normalize, + remove_diacritics=remove_diacritics, **kwargs, ) if decode_with_timestamps: @@ -385,12 +399,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): return {"text": text, "offsets": offsets} return text - def _decode(self, *args, normalize: bool = False, **kwargs) -> str: + def _decode( + self, *args, normalize: bool = False, basic_normalize: bool = False, remove_diacritics: bool = False, **kwargs + ) -> str: text = super()._decode(*args, **kwargs) if normalize: clean_text = self._normalize(text) return clean_text + elif basic_normalize: + clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics) + return clean_text else: return text @@ -403,6 +422,16 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): normalizer = EnglishTextNormalizer(self.english_spelling_normalizer) return normalizer(text) + @staticmethod + # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._basic_normalize + def _basic_normalize(text, remove_diacritics=False): + """ + Normalize a given string using the `BasicTextNormalizer` class, which preforms commons transformation on + multilingual text. + """ + normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics) + return normalizer(text) + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: files = self._tokenizer.model.save(save_directory, name=filename_prefix) diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py index ef58768e22..be9e11de54 100644 --- a/tests/models/whisper/test_tokenization_whisper.py +++ b/tests/models/whisper/test_tokenization_whisper.py @@ -273,6 +273,40 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): self.assertEqual(expected_tokens, output_rust[1]) self.assertEqual(expected_indices, output_rust[2]) + def test_basic_normalizer(self): + tokenizer = self.get_tokenizer() + rust_tokenizer = self.get_rust_tokenizer() + + input_str = "Hola güey!" + expected_output_normalize = "hola güey " + expected_output_diacritics = "hola guey " + + # tokenizer tests + encoded_input = tokenizer(input_str).input_ids + decoded_output = tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=False) + self.assertEqual(decoded_output, input_str) + + decoded_output_normalize = tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=True) + self.assertEqual(decoded_output_normalize, expected_output_normalize) + + decoded_output_diacritics = tokenizer.decode( + encoded_input, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True + ) + self.assertEqual(decoded_output_diacritics, expected_output_diacritics) + + # fast tokenizer tests + encoded_input = rust_tokenizer(input_str).input_ids + decoded_output = rust_tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=False) + self.assertEqual(decoded_output, input_str) + + decoded_output_normalize = rust_tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=True) + self.assertEqual(decoded_output_normalize, expected_output_normalize) + + decoded_output_diacritics = rust_tokenizer.decode( + encoded_input, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True + ) + self.assertEqual(decoded_output_diacritics, expected_output_diacritics) + class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase): checkpoint_name = "openai/whisper-small.en"