[Whisper] Allow basic text normalization (#26149)

* [Whisper] Allow basic text normalization

* up

* style copies
This commit is contained in:
Sanchit Gandhi
2023-10-03 17:57:16 +01:00
committed by GitHub
parent bd6205919a
commit 57f44dc428
3 changed files with 100 additions and 10 deletions

View File

@@ -273,6 +273,40 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
self.assertEqual(expected_tokens, output_rust[1])
self.assertEqual(expected_indices, output_rust[2])
def test_basic_normalizer(self):
    """Check `decode` normalization options on the slow and the fast tokenizer.

    One input string is encoded, then decoded three ways per tokenizer:
    with `basic_normalize=False` (must round-trip to the input), with
    `basic_normalize=True`, and with `basic_normalize=True` plus
    `remove_diacritics=True`. Each result is compared to a fixed string.
    """
    slow_tokenizer = self.get_tokenizer()
    fast_tokenizer = self.get_rust_tokenizer()

    input_str = "Hola güey!"
    expected_output_normalize = "hola güey "
    expected_output_diacritics = "hola guey "

    # Slow tokenizer first, then the fast one; both are held to the same expectations.
    for tok in (slow_tokenizer, fast_tokenizer):
        token_ids = tok(input_str).input_ids

        # Without normalization, decoding must give back the original text.
        plain_text = tok.decode(token_ids, skip_special_tokens=True, basic_normalize=False)
        self.assertEqual(plain_text, input_str)

        # Basic normalization only: the diacritic in "güey" is kept.
        normalized_text = tok.decode(token_ids, skip_special_tokens=True, basic_normalize=True)
        self.assertEqual(normalized_text, expected_output_normalize)

        # Basic normalization with diacritics removed: "güey" becomes "guey".
        stripped_text = tok.decode(
            token_ids, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
        )
        self.assertEqual(stripped_text, expected_output_diacritics)
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en"