feat: Whisper prompting (#22496)

* initial working additions

* clean and rename, add cond stripping initial prompt to decode

* cleanup, edit create_initial_prompt_ids, add tests

* repo consistency, flip order of conditional

* fix error, move the processor fn to the tokenizer

* repo consistency, update test ids to corresponding tokenizer

* use convert_tokens_to_ids not get_vocab...

* use actual conditional in generate

* make sytle

* initial address comments

* initial working add new params to pipeline

* first draft of sequential generation for condition_on_previous_text

* add/update tests, make compatible with timestamps

* make compatible with diff. input kwargs and max length

* add None check

* add temperature check

* flip temp check operand

* refocusing to prev pr scope

* remove the params too

* make style

* edits, move max length incorporating prompt to whisper

* address comments

* remove asr pipeline prompt decoding, fix indexing

* address comments (more tests, validate prompt)

* un-comment out tests (from debug)

* remove old comment

* address comments

* fix typo

* remove timestamp token from test

* make style

* cleanup

* copy method to fast tokenizer, set max_new_tokens for test

* prompt_ids type just pt

* address Amy's comments

* make style
This commit is contained in:
Connor Henderson
2023-05-19 04:33:11 -04:00
committed by GitHub
parent a7920065f2
commit 2acedf4721
7 changed files with 272 additions and 15 deletions

View File

@@ -194,6 +194,25 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
merge = _find_longest_common_sequence([seq1, seq2, seq3])
self.assertEqual(merge, [1, 2, 3, 4, 5, 6, 7, 8])
def test_skip_special_tokens_skips_prompt_ids(self):
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()
# fmt: off
encoded_input = [
50361, 2221, 13, 2326, 388, 391, 50258, 50259, 50359,
50363, 1282, 264, 2674, 9156, 295, 1523, 11, 2221, 13,
2326, 388, 391, 13657, 365, 2681, 21296, 17711, 13, 50257,
]
# fmt: on
expected_with_special_tokens = "<|startofprev|> Mr. Quilter<|startoftranscript|><|en|><|transcribe|><|notimestamps|> On the general principles of art, Mr. Quilter writes with equal lucidity.<|endoftext|>"
expected_without_special_tokens = " On the general principles of art, Mr. Quilter writes with equal lucidity."
self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens)
self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens)
self.assertEqual(rust_tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens)
self.assertEqual(
rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens
)
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en"