[Whisper] Allow basic text normalization (#26149)
* [Whisper] Allow basic text normalization * up * style copies
This commit is contained in:
@@ -23,7 +23,7 @@ import regex as re
|
|||||||
|
|
||||||
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .english_normalizer import EnglishTextNormalizer
|
from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
|
||||||
|
|
||||||
|
|
||||||
VOCAB_FILES_NAMES = {
|
VOCAB_FILES_NAMES = {
|
||||||
@@ -510,6 +510,15 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
|
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
|
||||||
return normalizer(text)
|
return normalizer(text)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _basic_normalize(text, remove_diacritics=False):
|
||||||
|
"""
|
||||||
|
Normalize a given string using the `BasicTextNormalizer` class, which preforms commons transformation on
|
||||||
|
multilingual text.
|
||||||
|
"""
|
||||||
|
normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
|
||||||
|
return normalizer(text)
|
||||||
|
|
||||||
def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
|
def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_precision=0.02) -> str:
|
||||||
"""
|
"""
|
||||||
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
|
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
|
||||||
@@ -617,6 +626,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
output_offsets: bool = False,
|
output_offsets: bool = False,
|
||||||
time_precision=0.02,
|
time_precision=0.02,
|
||||||
decode_with_timestamps: bool = False,
|
decode_with_timestamps: bool = False,
|
||||||
|
normalize: bool = False,
|
||||||
|
basic_normalize: bool = False,
|
||||||
|
remove_diacritics: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -633,8 +645,6 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
clean_up_tokenization_spaces (`bool`, *optional*):
|
clean_up_tokenization_spaces (`bool`, *optional*):
|
||||||
Whether or not to clean up the tokenization spaces. If `None`, will default to
|
Whether or not to clean up the tokenization spaces. If `None`, will default to
|
||||||
`self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
|
`self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
|
||||||
kwargs (additional keyword arguments, *optional*):
|
|
||||||
Will be passed to the underlying model specific decode method.
|
|
||||||
output_offsets (`bool`, *optional*, defaults to `False`):
|
output_offsets (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
|
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
|
||||||
timestamps.
|
timestamps.
|
||||||
@@ -642,6 +652,17 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
The time ratio to convert from token to time.
|
The time ratio to convert from token to time.
|
||||||
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
|
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to decode with timestamps included in the raw text.
|
Whether or not to decode with timestamps included in the raw text.
|
||||||
|
normalize (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to apply the English text normalizer to the decoded text. Only applicable when the
|
||||||
|
target text is in English. Otherwise, the basic text normalizer should be applied.
|
||||||
|
basic_normalize (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to apply the Basic text normalizer to the decoded text. Applicable to multilingual
|
||||||
|
target text.
|
||||||
|
remove_diacritics (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to remove diacritics when applying the Basic text normalizer. Removing diacritics may
|
||||||
|
destroy information in the decoded text, hence it should be used with caution.
|
||||||
|
kwargs (additional keyword arguments, *optional*):
|
||||||
|
Will be passed to the underlying model specific decode method.
|
||||||
Returns:
|
Returns:
|
||||||
`str`: The decoded sentence.
|
`str`: The decoded sentence.
|
||||||
"""
|
"""
|
||||||
@@ -654,7 +675,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
filtered_ids,
|
filtered_ids,
|
||||||
skip_special_tokens=skip_special_tokens,
|
skip_special_tokens=skip_special_tokens,
|
||||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||||
decode_with_timestamps=decode_with_timestamps,
|
normalize=normalize,
|
||||||
|
basic_normalize=basic_normalize,
|
||||||
|
remove_diacritics=remove_diacritics,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
if decode_with_timestamps:
|
if decode_with_timestamps:
|
||||||
@@ -676,7 +699,8 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
token_ids: Union[int, List[int]],
|
token_ids: Union[int, List[int]],
|
||||||
skip_special_tokens: bool = False,
|
skip_special_tokens: bool = False,
|
||||||
normalize: bool = False,
|
normalize: bool = False,
|
||||||
decode_with_timestamps: bool = False,
|
basic_normalize: bool = False,
|
||||||
|
remove_diacritics: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
|
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
|
||||||
@@ -705,6 +729,9 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
if normalize:
|
if normalize:
|
||||||
clean_text = self._normalize(text)
|
clean_text = self._normalize(text)
|
||||||
return clean_text
|
return clean_text
|
||||||
|
elif basic_normalize:
|
||||||
|
clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics)
|
||||||
|
return clean_text
|
||||||
else:
|
else:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from tokenizers import AddedToken, pre_tokenizers, processors
|
|||||||
from ...tokenization_utils_base import BatchEncoding
|
from ...tokenization_utils_base import BatchEncoding
|
||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .english_normalizer import EnglishTextNormalizer
|
from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
|
||||||
from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr
|
from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr
|
||||||
|
|
||||||
|
|
||||||
@@ -331,6 +331,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
output_offsets: bool = False,
|
output_offsets: bool = False,
|
||||||
time_precision=0.02,
|
time_precision=0.02,
|
||||||
decode_with_timestamps: bool = False,
|
decode_with_timestamps: bool = False,
|
||||||
|
normalize: bool = False,
|
||||||
|
basic_normalize: bool = False,
|
||||||
|
remove_diacritics: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -347,8 +350,6 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
clean_up_tokenization_spaces (`bool`, *optional*):
|
clean_up_tokenization_spaces (`bool`, *optional*):
|
||||||
Whether or not to clean up the tokenization spaces. If `None`, will default to
|
Whether or not to clean up the tokenization spaces. If `None`, will default to
|
||||||
`self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
|
`self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
|
||||||
kwargs (additional keyword arguments, *optional*):
|
|
||||||
Will be passed to the underlying model specific decode method.
|
|
||||||
output_offsets (`bool`, *optional*, defaults to `False`):
|
output_offsets (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
|
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
|
||||||
timestamps.
|
timestamps.
|
||||||
@@ -356,6 +357,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
The time ratio to convert from token to time.
|
The time ratio to convert from token to time.
|
||||||
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
|
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to decode with timestamps included in the raw text.
|
Whether or not to decode with timestamps included in the raw text.
|
||||||
|
normalize (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to apply the English text normalizer to the decoded text. Only applicable when the
|
||||||
|
target text is in English. Otherwise, the basic text normalizer should be applied.
|
||||||
|
basic_normalize (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to apply the Basic text normalizer to the decoded text. Applicable to multilingual
|
||||||
|
target text.
|
||||||
|
remove_diacritics (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to remove diacritics when applying the Basic text normalizer. Removing diacritics may
|
||||||
|
destroy information in the decoded text, hence it should be used with caution.
|
||||||
|
kwargs (additional keyword arguments, *optional*):
|
||||||
|
Will be passed to the underlying model specific decode method.
|
||||||
Returns:
|
Returns:
|
||||||
`str`: The decoded sentence.
|
`str`: The decoded sentence.
|
||||||
"""
|
"""
|
||||||
@@ -368,7 +380,9 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
filtered_ids,
|
filtered_ids,
|
||||||
skip_special_tokens=skip_special_tokens,
|
skip_special_tokens=skip_special_tokens,
|
||||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||||
decode_with_timestamps=decode_with_timestamps,
|
normalize=normalize,
|
||||||
|
basic_normalize=basic_normalize,
|
||||||
|
remove_diacritics=remove_diacritics,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
if decode_with_timestamps:
|
if decode_with_timestamps:
|
||||||
@@ -385,12 +399,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
return {"text": text, "offsets": offsets}
|
return {"text": text, "offsets": offsets}
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def _decode(self, *args, normalize: bool = False, **kwargs) -> str:
|
def _decode(
|
||||||
|
self, *args, normalize: bool = False, basic_normalize: bool = False, remove_diacritics: bool = False, **kwargs
|
||||||
|
) -> str:
|
||||||
text = super()._decode(*args, **kwargs)
|
text = super()._decode(*args, **kwargs)
|
||||||
|
|
||||||
if normalize:
|
if normalize:
|
||||||
clean_text = self._normalize(text)
|
clean_text = self._normalize(text)
|
||||||
return clean_text
|
return clean_text
|
||||||
|
elif basic_normalize:
|
||||||
|
clean_text = self._basic_normalize(text, remove_diacritics=remove_diacritics)
|
||||||
|
return clean_text
|
||||||
else:
|
else:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
@@ -403,6 +422,16 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
|
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
|
||||||
return normalizer(text)
|
return normalizer(text)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._basic_normalize
|
||||||
|
def _basic_normalize(text, remove_diacritics=False):
|
||||||
|
"""
|
||||||
|
Normalize a given string using the `BasicTextNormalizer` class, which preforms commons transformation on
|
||||||
|
multilingual text.
|
||||||
|
"""
|
||||||
|
normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
|
||||||
|
return normalizer(text)
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
|
files = self._tokenizer.model.save(save_directory, name=filename_prefix)
|
||||||
|
|
||||||
|
|||||||
@@ -273,6 +273,40 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
self.assertEqual(expected_tokens, output_rust[1])
|
self.assertEqual(expected_tokens, output_rust[1])
|
||||||
self.assertEqual(expected_indices, output_rust[2])
|
self.assertEqual(expected_indices, output_rust[2])
|
||||||
|
|
||||||
|
def test_basic_normalizer(self):
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
rust_tokenizer = self.get_rust_tokenizer()
|
||||||
|
|
||||||
|
input_str = "Hola güey!"
|
||||||
|
expected_output_normalize = "hola güey "
|
||||||
|
expected_output_diacritics = "hola guey "
|
||||||
|
|
||||||
|
# tokenizer tests
|
||||||
|
encoded_input = tokenizer(input_str).input_ids
|
||||||
|
decoded_output = tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=False)
|
||||||
|
self.assertEqual(decoded_output, input_str)
|
||||||
|
|
||||||
|
decoded_output_normalize = tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=True)
|
||||||
|
self.assertEqual(decoded_output_normalize, expected_output_normalize)
|
||||||
|
|
||||||
|
decoded_output_diacritics = tokenizer.decode(
|
||||||
|
encoded_input, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
|
||||||
|
)
|
||||||
|
self.assertEqual(decoded_output_diacritics, expected_output_diacritics)
|
||||||
|
|
||||||
|
# fast tokenizer tests
|
||||||
|
encoded_input = rust_tokenizer(input_str).input_ids
|
||||||
|
decoded_output = rust_tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=False)
|
||||||
|
self.assertEqual(decoded_output, input_str)
|
||||||
|
|
||||||
|
decoded_output_normalize = rust_tokenizer.decode(encoded_input, skip_special_tokens=True, basic_normalize=True)
|
||||||
|
self.assertEqual(decoded_output_normalize, expected_output_normalize)
|
||||||
|
|
||||||
|
decoded_output_diacritics = rust_tokenizer.decode(
|
||||||
|
encoded_input, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
|
||||||
|
)
|
||||||
|
self.assertEqual(decoded_output_diacritics, expected_output_diacritics)
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||||
checkpoint_name = "openai/whisper-small.en"
|
checkpoint_name = "openai/whisper-small.en"
|
||||||
|
|||||||
Reference in New Issue
Block a user