[Whisper] Allow basic text normalization (#26149)
* [Whisper] Allow basic text normalization * up * style copies
This commit is contained in:
@@ -23,7 +23,7 @@ import regex as re
|
||||
|
||||
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
from .english_normalizer import EnglishTextNormalizer
|
||||
from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
|
||||
|
||||
|
||||
VOCAB_FILES_NAMES = {
|
||||
@@ -510,6 +510,15 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
||||
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
|
||||
return normalizer(text)
|
||||
|
||||
@staticmethod
|
||||
def _basic_normalize(text, remove_diacritics=False):
|
||||
"""
|
||||
Normalize a given string using the `BasicTextNormalizer` class, which preforms commons transformation on
|
||||
multilingual text.
|
||||
"""
|
||||
normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
|
||||
return normalizer(text)
|
||||
|
||||
def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
|
||||
"""
|
||||
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
|
||||
@@ -617,6 +626,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
||||
output_offsets: bool = False,
|
||||
time_precision=0.02,
|
||||
decode_with_timestamps: bool = False,
|
||||
normalize: bool = False,
|
||||
basic_normalize: bool = False,
|
||||
remove_diacritics: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -633,8 +645,6 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
||||
clean_up_tokenization_spaces (`bool`, *optional*):
|
||||
Whether or not to clean up the tokenization spaces. If `None`, will default to
|
||||
`self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
|
||||
kwargs (additional keyword arguments, *optional*):
|
||||
Will be passed to the underlying model specific decode method.
|
||||
output_offsets (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
|
||||
timestamps.
|
||||
@@ -642,6 +652,17 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
||||
The time ratio to convert from token to time.
|
||||
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to decode with timestamps included in the raw text.
|
||||
normalize (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to apply the English text normalizer to the decoded text. Only applicable when the
|
||||
target text is in English. Otherwise, the basic text normalizer should be applied.
|
||||
basic_normalize (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to apply the Basic text normalizer to the decoded text. Applicable to multilingual
|
||||
target text.
|
||||
remove_diacritics (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to remove diacritics when applying the Basic text normalizer. Removing diacritics may
|
||||
destroy information in the decoded text, hence it should be used with caution.
|
||||
kwargs (additional keyword arguments, *optional*):
|
||||
Will be passed to the underlying model specific decode method.
|
||||
Returns:
|
||||
`str`: The decoded sentence.
|
||||
"""
|
||||
@@ -654,7 +675,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
||||
filtered_ids,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
decode_with_timestamps=decode_with_timestamps,
|
||||
normalize=normalize,
|
||||
basic_normalize=basic_normalize,
|
||||
remove_diacritics=remove_diacritics,
|
||||
**kwargs,
|
||||
)
|
||||
if decode_with_timestamps:
|
||||
@@ -676,7 +699,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
||||
token_ids: Union[int, List[int]],
|
||||
skip_special_tokens: bool = False,
|
||||
normalize: bool = False,
|
||||
decode_with_timestamps: bool = False,
|
||||
basic_normalize: bool = False,
|
||||
remove_diacritics: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
|
||||
@@ -705,6 +729,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
||||
if normalize:
|
||||
clean_text = self._normalize(text)
|
||||
return clean_text
|
||||
elif basic_normalize:
|
||||
clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics)
|
||||
return clean_text
|
||||
else:
|
||||
return text
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ from tokenizers import AddedToken, pre_tokenizers, processors
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ...utils import logging
|
||||
from .english_normalizer import EnglishTextNormalizer
|
||||
from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
|
||||
from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr
|
||||
|
||||
|
||||
@@ -331,6 +331,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
||||
output_offsets: bool = False,
|
||||
time_precision=0.02,
|
||||
decode_with_timestamps: bool = False,
|
||||
normalize: bool = False,
|
||||
basic_normalize: bool = False,
|
||||
remove_diacritics: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -347,8 +350,6 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
||||
clean_up_tokenization_spaces (`bool`, *optional*):
|
||||
Whether or not to clean up the tokenization spaces. If `None`, will default to
|
||||
`self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
|
||||
kwargs (additional keyword arguments, *optional*):
|
||||
Will be passed to the underlying model specific decode method.
|
||||
output_offsets (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
|
||||
timestamps.
|
||||
@@ -356,6 +357,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
||||
The time ratio to convert from token to time.
|
||||
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to decode with timestamps included in the raw text.
|
||||
normalize (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to apply the English text normalizer to the decoded text. Only applicable when the
|
||||
target text is in English. Otherwise, the basic text normalizer should be applied.
|
||||
basic_normalize (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to apply the Basic text normalizer to the decoded text. Applicable to multilingual
|
||||
target text.
|
||||
remove_diacritics (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to remove diacritics when applying the Basic text normalizer. Removing diacritics may
|
||||
destroy information in the decoded text, hence it should be used with caution.
|
||||
kwargs (additional keyword arguments, *optional*):
|
||||
Will be passed to the underlying model specific decode method.
|
||||
Returns:
|
||||
`str`: The decoded sentence.
|
||||
"""
|
||||
@@ -368,7 +380,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
||||
filtered_ids,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
decode_with_timestamps=decode_with_timestamps,
|
||||
normalize=normalize,
|
||||
basic_normalize=basic_normalize,
|
||||
remove_diacritics=remove_diacritics,
|
||||
**kwargs,
|
||||
)
|
||||
if decode_with_timestamps:
|
||||
@@ -385,12 +399,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
||||
return {"text": text, "offsets": offsets}
|
||||
return text
|
||||
|
||||
def _decode(self, *args, normalize: bool = False, **kwargs) -> str:
|
||||
def _decode(
|
||||
self, *args, normalize: bool = False, basic_normalize: bool = False, remove_diacritics: bool = False, **kwargs
|
||||
) -> str:
|
||||
text = super()._decode(*args, **kwargs)
|
||||
|
||||
if normalize:
|
||||
clean_text = self._normalize(text)
|
||||
return clean_text
|
||||
elif basic_normalize:
|
||||
clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics)
|
||||
return clean_text
|
||||
else:
|
||||
return text
|
||||
|
||||
@@ -403,6 +422,16 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
||||
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
|
||||
return normalizer(text)
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._basic_normalize
|
||||
def _basic_normalize(text, remove_diacritics=False):
|
||||
"""
|
||||
Normalize a given string using the `BasicTextNormalizer` class, which preforms commons transformation on
|
||||
multilingual text.
|
||||
"""
|
||||
normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
|
||||
return normalizer(text)
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
|
||||
|
||||
|
||||
@@ -273,6 +273,40 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(expected_tokens, output_rust[1])
|
||||
self.assertEqual(expected_indices, output_rust[2])
|
||||
|
||||
def test_basic_normalizer(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
rust_tokenizer = self.get_rust_tokenizer()
|
||||
|
||||
input_str = "Hola güey!"
|
||||
expected_output_normalize = "hola güey "
|
||||
expected_output_diacritics = "hola guey "
|
||||
|
||||
# tokenizer tests
|
||||
encoded_input = tokenizer(input_str).input_ids
|
||||
decoded_output = tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=False)
|
||||
self.assertEqual(decoded_output, input_str)
|
||||
|
||||
decoded_output_normalize = tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=True)
|
||||
self.assertEqual(decoded_output_normalize, expected_output_normalize)
|
||||
|
||||
decoded_output_diacritics = tokenizer.decode(
|
||||
encoded_input, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
|
||||
)
|
||||
self.assertEqual(decoded_output_diacritics, expected_output_diacritics)
|
||||
|
||||
# fast tokenizer tests
|
||||
encoded_input = rust_tokenizer(input_str).input_ids
|
||||
decoded_output = rust_tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=False)
|
||||
self.assertEqual(decoded_output, input_str)
|
||||
|
||||
decoded_output_normalize = rust_tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=True)
|
||||
self.assertEqual(decoded_output_normalize, expected_output_normalize)
|
||||
|
||||
decoded_output_diacritics = rust_tokenizer.decode(
|
||||
encoded_input, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
|
||||
)
|
||||
self.assertEqual(decoded_output_diacritics, expected_output_diacritics)
|
||||
|
||||
|
||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||
checkpoint_name = "openai/whisper-small.en"
|
||||
|
||||
Reference in New Issue
Block a user