From 8ed635258cee5b29256f3c7c4a3f4a254d8743b2 Mon Sep 17 00:00:00 2001 From: Hannan Komari Date: Thu, 12 Sep 2024 14:51:59 +0330 Subject: [PATCH] Fix flax whisper tokenizer bug (#33151) * Update tokenization_whisper.py Fix issue with flax whisper model * Update tokenization_whisper_fast.py Fix issue with flax whisper model * Update tokenization_whisper.py just check len of token_ids * Update tokenization_whisper_fast.py just use len of token_ids * Update tokenization_whisper_fast.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list * Update tokenization_whisper.py and revert changes in _strip_prompt and add support to jax arrays in _convert_to_list * Update test_tokenization_whisper.py to add test for _convert_to_list method * Update test_tokenization_whisper.py to fix code style issues * Fix code style * Fix code check again * Update test_tokenization_whisper.py to improve code style * Update test_tokenization_whisper.py to run each of jax, tf and flax modules if available * Update tests/models/whisper/test_tokenization_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update test_tokenization_whisper.py and use require_xxx decorators instead of `is_xxx_available()` method * Revert the changes automatically applied by formatter and was unrelated to PR * Format for minimal changes --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../models/whisper/tokenization_whisper.py | 2 + .../whisper/tokenization_whisper_fast.py | 2 + .../whisper/test_tokenization_whisper.py | 41 ++++++++++++++++++- 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 823a11c3ec..0a6eb75c55 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -880,6 +880,8 @@ class 
WhisperTokenizer(PreTrainedTokenizer): token_ids = token_ids.cpu().numpy() elif "tensorflow" in str(type(token_ids)): token_ids = token_ids.numpy() + elif "jaxlib" in str(type(token_ids)): + token_ids = token_ids.tolist() # now the token ids are either a numpy array, or a list of lists if isinstance(token_ids, np.ndarray): token_ids = token_ids.tolist() diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 11c2b46567..66cf412cc2 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -613,6 +613,8 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): token_ids = token_ids.cpu().numpy() elif "tensorflow" in str(type(token_ids)): token_ids = token_ids.numpy() + elif "jaxlib" in str(type(token_ids)): + token_ids = token_ids.tolist() # now the token ids are either a numpy array, or a list of lists if isinstance(token_ids, np.ndarray): token_ids = token_ids.tolist() diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py index 5c653f1984..27b24448d5 100644 --- a/tests/models/whisper/test_tokenization_whisper.py +++ b/tests/models/whisper/test_tokenization_whisper.py @@ -18,7 +18,7 @@ import numpy as np from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence -from transformers.testing_utils import slow +from transformers.testing_utils import require_flax, require_tf, require_torch, slow from ...test_tokenization_common import TokenizerTesterMixin @@ -574,3 +574,42 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase): output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"] self.assertEqual(output, []) + + def test_convert_to_list_np(self): + test_list = 
[[1, 2, 3], [4, 5, 6]] + + # Test with an already converted list + self.assertListEqual(WhisperTokenizer._convert_to_list(test_list), test_list) + self.assertListEqual(WhisperTokenizerFast._convert_to_list(test_list), test_list) + + # Test with a numpy array + np_array = np.array(test_list) + self.assertListEqual(WhisperTokenizer._convert_to_list(np_array), test_list) + self.assertListEqual(WhisperTokenizerFast._convert_to_list(np_array), test_list) + + @require_tf + def test_convert_to_list_tf(self): + import tensorflow as tf + + test_list = [[1, 2, 3], [4, 5, 6]] + tf_tensor = tf.constant(test_list) + self.assertListEqual(WhisperTokenizer._convert_to_list(tf_tensor), test_list) + self.assertListEqual(WhisperTokenizerFast._convert_to_list(tf_tensor), test_list) + + @require_flax + def test_convert_to_list_jax(self): + import jax.numpy as jnp + + test_list = [[1, 2, 3], [4, 5, 6]] + jax_array = jnp.array(test_list) + self.assertListEqual(WhisperTokenizer._convert_to_list(jax_array), test_list) + self.assertListEqual(WhisperTokenizerFast._convert_to_list(jax_array), test_list) + + @require_torch + def test_convert_to_list_pt(self): + import torch + + test_list = [[1, 2, 3], [4, 5, 6]] + torch_tensor = torch.tensor(test_list) + self.assertListEqual(WhisperTokenizer._convert_to_list(torch_tensor), test_list) + self.assertListEqual(WhisperTokenizerFast._convert_to_list(torch_tensor), test_list)