fix prompt strip to support tensors and np arrays (#27818)

* fix prompt strip to support tensors and np arrays

* framework agnostic

* change logic check before converting prompt into list

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* adding _convert_to_list to tokenization_whisper_fast

* adding tests for prompt decoding

* adding comment

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* adding comment

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* revert minor

* make style formatting

* style formatting after update

* Update src/transformers/models/whisper/tokenization_whisper_fast.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* fixing _strip_prompt to handle _decode_with_timestamps

* fix copies

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
Aviv Shamsian
2024-07-12 22:07:10 +03:00
committed by GitHub
parent d1a1bcf56a
commit 7f79a97399
3 changed files with 76 additions and 6 deletions

View File

@@ -14,6 +14,8 @@
import unittest
import numpy as np
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
from transformers.testing_utils import slow
@@ -251,6 +253,39 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist())
def test_tokenizer_decode_prompt(self):
    """Check prompt decoding for both the slow and the fast tokenizer.

    A prompt and an input text are encoded to numpy arrays and stacked into
    a single id sequence per tokenizer. Each sequence is then decoded twice:

    * with ``skip_special_tokens=False`` the output must contain the prompt
      and the special tokens;
    * with ``skip_special_tokens=True`` the output must equal the input
      text alone, i.e. the prompt has been stripped.
    """
    prompt_text = "What does the fox say?"
    input_text = "Hatee hatee hatee ho"

    tokenizer = self.get_tokenizer()
    rust_tokenizer = self.get_rust_tokenizer()

    # encode prompt and input text using tokenizer
    prompt_ids = tokenizer.get_prompt_ids(prompt_text, return_tensors="np")
    input_ids = tokenizer(input_text, return_tensors="np").input_ids[0]
    input_ids = np.hstack([prompt_ids, input_ids])

    # encode using fast tokenizer
    rust_prompt_ids = rust_tokenizer.get_prompt_ids(prompt_text, return_tensors="np")
    rust_input_ids = rust_tokenizer(input_text, return_tensors="np").input_ids[0]
    rust_input_ids = np.hstack([rust_prompt_ids, rust_input_ids])

    # check with prompt in output
    pred_text = tokenizer.decode(input_ids, skip_special_tokens=False)
    rust_pred_text = rust_tokenizer.decode(rust_input_ids, skip_special_tokens=False)

    # check correctness for both tokenizers
    expected_text = f"<|startofprev|> {prompt_text}<|startoftranscript|><|notimestamps|>{input_text}<|endoftext|>"
    self.assertEqual(pred_text.strip(), expected_text)
    self.assertEqual(rust_pred_text.strip(), expected_text)

    # check stripping prompt from output
    # NOTE: the fast tokenizer must decode its own ids here; decoding with the
    # slow tokenizer twice would leave the fast prompt-stripping path untested.
    pred_text = tokenizer.decode(input_ids, skip_special_tokens=True)
    rust_pred_text = rust_tokenizer.decode(rust_input_ids, skip_special_tokens=True)

    self.assertEqual(pred_text.strip(), input_text)
    self.assertEqual(rust_pred_text.strip(), input_text)
def test_combine_tokens_into_words(self):
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()