From da3c79b245afcce88f5db79ada10bf5b7c200ab1 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Mon, 29 Jan 2024 16:07:35 +0000 Subject: [PATCH] [Whisper] Make tokenizer normalization public (#28136) * [Whisper] Make tokenizer normalization public * add to docs --- docs/source/en/model_doc/whisper.md | 4 ++++ .../models/whisper/tokenization_whisper.py | 21 ++++++++++++++++--- .../whisper/tokenization_whisper_fast.py | 21 +++++++++++++++++-- 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/docs/source/en/model_doc/whisper.md b/docs/source/en/model_doc/whisper.md index 37411209bf..e384d2be90 100644 --- a/docs/source/en/model_doc/whisper.md +++ b/docs/source/en/model_doc/whisper.md @@ -102,6 +102,8 @@ python convert_hf_to_openai.py \ - save_vocabulary - batch_decode - decode + - basic_normalize + - normalize ## WhisperTokenizerFast @@ -113,6 +115,8 @@ python convert_hf_to_openai.py \ - save_vocabulary - batch_decode - decode + - basic_normalize + - normalize ## WhisperFeatureExtractor diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 127f5be619..f853c60e26 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -15,6 +15,7 @@ """Tokenization classes for Whisper.""" import json import os +import warnings from functools import lru_cache from typing import List, Optional, Tuple, Union @@ -507,6 +508,20 @@ class WhisperTokenizer(PreTrainedTokenizer): return self.decoder.get(index, "") def _normalize(self, text): + warnings.warn( + "The private method `_normalize` is deprecated and will be removed in v5 of Transformers. " + "You can normalize an input string using the Whisper English normalizer using the `normalize` method."
+ ) + return self.normalize(text) + + def _basic_normalize(self, text, remove_diacritics=False): + warnings.warn( + "The private method `_basic_normalize` is deprecated and will be removed in v5 of Transformers. " + "You can normalize an input string using the Whisper basic normalizer using the `basic_normalize` method." + ) + return self.basic_normalize(text, remove_diacritics=remove_diacritics) + + def normalize(self, text): """ Normalize a given string using the `EnglishTextNormalizer` class, which preforms commons transformation on english text. @@ -515,7 +530,7 @@ class WhisperTokenizer(PreTrainedTokenizer): return normalizer(text) @staticmethod - def _basic_normalize(text, remove_diacritics=False): + def basic_normalize(text, remove_diacritics=False): """ Normalize a given string using the `BasicTextNormalizer` class, which preforms commons transformation on multilingual text. @@ -745,10 +760,10 @@ class WhisperTokenizer(PreTrainedTokenizer): text = "".join(sub_texts) if normalize: - clean_text = self._normalize(text) + clean_text = self.normalize(text) return clean_text elif basic_normalize: - clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics) + clean_text = self.basic_normalize(text, remove_diacritics=remove_diacritics) return clean_text else: return text diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 509175be99..dc5a3e0dc1 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -16,6 +16,7 @@ import json import os import re +import warnings from functools import lru_cache from typing import List, Optional, Tuple @@ -427,6 +428,22 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._normalize def _normalize(self, text): + warnings.warn( + "The private method
`_normalize` is deprecated and will be removed in v5 of Transformers. " + "You can normalize an input string using the Whisper English normalizer using the `normalize` method." + ) + return self.normalize(text) + + # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._basic_normalize + def _basic_normalize(self, text, remove_diacritics=False): + warnings.warn( + "The private method `_basic_normalize` is deprecated and will be removed in v5 of Transformers. " + "You can normalize an input string using the Whisper basic normalizer using the `basic_normalize` method." + ) + return self.basic_normalize(text, remove_diacritics=remove_diacritics) + + # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.normalize + def normalize(self, text): """ Normalize a given string using the `EnglishTextNormalizer` class, which preforms commons transformation on english text. @@ -435,8 +452,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): return normalizer(text) @staticmethod - # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._basic_normalize - def _basic_normalize(text, remove_diacritics=False): + # Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer.basic_normalize + def basic_normalize(text, remove_diacritics=False): """ Normalize a given string using the `BasicTextNormalizer` class, which preforms commons transformation on multilingual text.