[Whisper] Allow basic text normalization (#26149)
* [Whisper] Allow basic text normalization
* up
* style copies
This commit is contained in:
@@ -273,6 +273,40 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(expected_tokens, output_rust[1])
|
||||
self.assertEqual(expected_indices, output_rust[2])
|
||||
|
||||
def test_basic_normalizer(self):
    """Check the `basic_normalize` / `remove_diacritics` options of `decode`.

    The same three round-trip checks are run against the slow tokenizer and
    then the fast (rust) tokenizer:

    1. `basic_normalize=False` must give back the input string unchanged.
    2. `basic_normalize=True` must give the normalized reference string.
    3. `basic_normalize=True, remove_diacritics=True` must give the
       normalized reference string with the diacritic removed.
    """
    slow_tokenizer = self.get_tokenizer()
    fast_tokenizer = self.get_rust_tokenizer()

    raw_text = "Hola güey!"
    normalized_text = "hola güey "
    normalized_text_without_diacritics = "hola guey "

    # Slow tokenizer first, then the fast one; both must agree with the same references.
    for tok in (slow_tokenizer, fast_tokenizer):
        token_ids = tok(raw_text).input_ids

        # Normalization switched off: decoding is a plain round trip.
        plain = tok.decode(token_ids, skip_special_tokens=True, basic_normalize=False)
        self.assertEqual(plain, raw_text)

        # Basic normalization only: the diacritic is kept.
        normalized = tok.decode(token_ids, skip_special_tokens=True, basic_normalize=True)
        self.assertEqual(normalized, normalized_text)

        # Basic normalization plus diacritics removal.
        stripped = tok.decode(
            token_ids, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
        )
        self.assertEqual(stripped, normalized_text_without_diacritics)
|
||||
|
||||
|
||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||
checkpoint_name = "openai/whisper-small.en"
|
||||
|
||||
Reference in New Issue
Block a user