fix prompt strip to support tensors and np arrays (#27818)
* fix prompt strip to support tensors and np arrays * framework agnostic * change logic check before converting prompt into list Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * adding _convert_to_list to tokenization_whisper_fast * adding tests for prompt decoding * adding comment Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * adding comment Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * revert minor * make style formatting * style formatting after update * Update src/transformers/models/whisper/tokenization_whisper_fast.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * fixing _strip_prompt to handle _decode_with_timestamps * fix copies --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -851,9 +851,16 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
batch_encoding.convert_to_tensors(tensor_type=return_tensors)
|
batch_encoding.convert_to_tensors(tensor_type=return_tensors)
|
||||||
return batch_encoding["input_ids"]
|
return batch_encoding["input_ids"]
|
||||||
|
|
||||||
@staticmethod
|
def _strip_prompt(self, token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
|
||||||
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
|
if not isinstance(token_ids, list):
|
||||||
has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
|
token_ids = self._convert_to_list(token_ids)
|
||||||
|
|
||||||
|
# handle case of empty token_ids for decoding with timestamps.
|
||||||
|
# at this point token_ids is a list, so it is safe to use if not check.
|
||||||
|
if not token_ids:
|
||||||
|
return token_ids
|
||||||
|
|
||||||
|
has_prompt = token_ids[0] == prompt_token_id
|
||||||
if has_prompt:
|
if has_prompt:
|
||||||
if decoder_start_token_id in token_ids:
|
if decoder_start_token_id in token_ids:
|
||||||
return token_ids[token_ids.index(decoder_start_token_id) :]
|
return token_ids[token_ids.index(decoder_start_token_id) :]
|
||||||
@@ -862,6 +869,16 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
|
|
||||||
return token_ids
|
return token_ids
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_to_list(token_ids):
|
||||||
|
# convert type to ndarray if necessary
|
||||||
|
if "torch" in str(type(token_ids)) or "tensorflow" in str(type(token_ids)) and hasattr(token_ids, "numpy"):
|
||||||
|
token_ids = token_ids.numpy()
|
||||||
|
# now the token ids are either a numpy array, or a list of lists
|
||||||
|
if isinstance(token_ids, np.ndarray):
|
||||||
|
token_ids = token_ids.tolist()
|
||||||
|
return token_ids
|
||||||
|
|
||||||
|
|
||||||
def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
|
def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -582,10 +582,17 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
batch_encoding.convert_to_tensors(tensor_type=return_tensors)
|
batch_encoding.convert_to_tensors(tensor_type=return_tensors)
|
||||||
return batch_encoding["input_ids"]
|
return batch_encoding["input_ids"]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._strip_prompt
|
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._strip_prompt
|
||||||
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
|
def _strip_prompt(self, token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
|
||||||
has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
|
if not isinstance(token_ids, list):
|
||||||
|
token_ids = self._convert_to_list(token_ids)
|
||||||
|
|
||||||
|
# handle case of empty token_ids for decoding with timestamps.
|
||||||
|
# at this point token_ids is a list, so it is safe to use if not check.
|
||||||
|
if not token_ids:
|
||||||
|
return token_ids
|
||||||
|
|
||||||
|
has_prompt = token_ids[0] == prompt_token_id
|
||||||
if has_prompt:
|
if has_prompt:
|
||||||
if decoder_start_token_id in token_ids:
|
if decoder_start_token_id in token_ids:
|
||||||
return token_ids[token_ids.index(decoder_start_token_id) :]
|
return token_ids[token_ids.index(decoder_start_token_id) :]
|
||||||
@@ -593,3 +600,14 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
return token_ids
|
return token_ids
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._convert_to_list
|
||||||
|
def _convert_to_list(token_ids):
|
||||||
|
# convert type to ndarray if necessary
|
||||||
|
if "torch" in str(type(token_ids)) or "tensorflow" in str(type(token_ids)) and hasattr(token_ids, "numpy"):
|
||||||
|
token_ids = token_ids.numpy()
|
||||||
|
# now the token ids are either a numpy array, or a list of lists
|
||||||
|
if isinstance(token_ids, np.ndarray):
|
||||||
|
token_ids = token_ids.tolist()
|
||||||
|
return token_ids
|
||||||
|
|||||||
@@ -14,6 +14,8 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
|
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
|
||||||
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
|
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
|
||||||
from transformers.testing_utils import slow
|
from transformers.testing_utils import slow
|
||||||
@@ -251,6 +253,39 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist())
|
self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist())
|
||||||
|
|
||||||
|
def test_tokenizer_decode_prompt(self):
|
||||||
|
prompt_text = "What does the fox say?"
|
||||||
|
input_text = "Hatee hatee hatee ho"
|
||||||
|
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
rust_tokenizer = self.get_rust_tokenizer()
|
||||||
|
|
||||||
|
# encode prompt and input text using tokenizer
|
||||||
|
prompt_ids = tokenizer.get_prompt_ids(prompt_text, return_tensors="np")
|
||||||
|
input_ids = tokenizer(input_text, return_tensors="np").input_ids[0]
|
||||||
|
input_ids = np.hstack([prompt_ids, input_ids])
|
||||||
|
|
||||||
|
# encode using fast tokenizer
|
||||||
|
rust_prompt_ids = rust_tokenizer.get_prompt_ids(prompt_text, return_tensors="np")
|
||||||
|
rust_input_ids = rust_tokenizer(input_text, return_tensors="np").input_ids[0]
|
||||||
|
rust_input_ids = np.hstack([rust_prompt_ids, rust_input_ids])
|
||||||
|
|
||||||
|
# check with prompt in output
|
||||||
|
pred_text = tokenizer.decode(input_ids, skip_special_tokens=False)
|
||||||
|
rust_pred_text = rust_tokenizer.decode(rust_input_ids, skip_special_tokens=False)
|
||||||
|
|
||||||
|
# check correctness for both tokenizers
|
||||||
|
expected_text = f"<|startofprev|> {prompt_text}<|startoftranscript|><|notimestamps|>{input_text}<|endoftext|>"
|
||||||
|
self.assertEqual(pred_text.strip(), expected_text)
|
||||||
|
self.assertEqual(rust_pred_text.strip(), expected_text)
|
||||||
|
|
||||||
|
# check stripping prompt from output
|
||||||
|
pred_text = tokenizer.decode(input_ids, skip_special_tokens=True)
|
||||||
|
rust_pred_text = tokenizer.decode(input_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
self.assertEqual(pred_text.strip(), input_text)
|
||||||
|
self.assertEqual(rust_pred_text.strip(), input_text)
|
||||||
|
|
||||||
def test_combine_tokens_into_words(self):
|
def test_combine_tokens_into_words(self):
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
rust_tokenizer = self.get_rust_tokenizer()
|
rust_tokenizer = self.get_rust_tokenizer()
|
||||||
|
|||||||
Reference in New Issue
Block a user