feat: Whisper prompting (#22496)
* initial working additions * clean and rename, add cond stripping initial prompt to decode * cleanup, edit create_initial_prompt_ids, add tests * repo consistency, flip order of conditional * fix error, move the processor fn to the tokenizer * repo consistency, update test ids to corresponding tokenizer * use convert_tokens_to_ids not get_vocab... * use actual conditional in generate * make style * initial address comments * initial working add new params to pipeline * first draft of sequential generation for condition_on_previous_text * add/update tests, make compatible with timestamps * make compatible with diff. input kwargs and max length * add None check * add temperature check * flip temp check operand * refocusing to prev pr scope * remove the params too * make style * edits, move max length incorporating prompt to whisper * address comments * remove asr pipeline prompt decoding, fix indexing * address comments (more tests, validate prompt) * un-comment out tests (from debug) * remove old comment * address comments * fix typo * remove timestamp token from test * make style * cleanup * copy method to fast tokenizer, set max_new_tokens for test * prompt_ids type just pt * address Amy's comments * make style
This commit is contained in:
@@ -16,6 +16,8 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import WhisperTokenizer, is_speech_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio
|
||||
|
||||
@@ -146,3 +148,32 @@ class WhisperProcessorTest(unittest.TestCase):
|
||||
|
||||
expected_ids = [TRANSCRIBE, NOTIMESTAMPS]
|
||||
self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids)
|
||||
|
||||
def test_get_prompt_ids(self):
    """Check the ids and decoded text produced for the prompt "Mr. Quilter"."""
    tokenizer = self.get_tokenizer()
    feature_extractor = self.get_feature_extractor()
    processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

    # Expected values pinned by this test for the tokenizer fixture in use.
    expected_ids = [50360, 1770, 13, 2264, 346, 353]
    expected_text = "<|startofprev|> Mr. Quilter"

    ids = processor.get_prompt_ids("Mr. Quilter")
    # Decode before asserting so both values come from the same call result.
    round_tripped = processor.tokenizer.decode(ids)

    self.assertListEqual(ids.tolist(), expected_ids)
    self.assertEqual(round_tripped, expected_text)
|
||||
|
||||
def test_empty_get_prompt_ids(self):
    """Check the ids and decoded text produced for an empty prompt string."""
    tokenizer = self.get_tokenizer()
    feature_extractor = self.get_feature_extractor()
    processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

    # Expected values pinned by this test for the tokenizer fixture in use.
    expected_ids = [50360, 220]
    expected_text = "<|startofprev|> "

    ids = processor.get_prompt_ids("")
    # Decode before asserting so both values come from the same call result.
    round_tripped = processor.tokenizer.decode(ids)

    self.assertListEqual(ids.tolist(), expected_ids)
    self.assertEqual(round_tripped, expected_text)
|
||||
|
||||
def test_get_prompt_ids_with_special_tokens(self):
    """Check that prompts containing special-token text raise ValueError.

    Each case pairs a prompt with the special token expected to be named
    in the error message.
    """
    tokenizer = self.get_tokenizer()
    feature_extractor = self.get_feature_extractor()
    processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

    # (prompt, token expected in the error message), checked in this order.
    cases = (
        ("<|startofprev|> test", "<|startofprev|>"),
        ("test <|notimestamps|>", "<|notimestamps|>"),
        # Two special tokens in the prompt; the test expects "<|zh|>" to be reported.
        ("test <|zh|> test <|transcribe|>", "<|zh|>"),
    )

    for prompt, special_token in cases:
        with pytest.raises(ValueError) as excinfo:
            processor.get_prompt_ids(prompt)
        expected = f"Encountered text in the prompt corresponding to disallowed special token: {special_token}."
        self.assertEqual(expected, str(excinfo.value))
|
||||
|
||||
Reference in New Issue
Block a user