feat: Whisper prompting (#22496)
* initial working additions * clean and rename, add cond stripping initial prompt to decode * cleanup, edit create_initial_prompt_ids, add tests * repo consistency, flip order of conditional * fix error, move the processor fn to the tokenizer * repo consistency, update test ids to corresponding tokenizer * use convert_tokens_to_ids not get_vocab... * use actual conditional in generate * make sytle * initial address comments * initial working add new params to pipeline * first draft of sequential generation for condition_on_previous_text * add/update tests, make compatible with timestamps * make compatible with diff. input kwargs and max length * add None check * add temperature check * flip temp check operand * refocusing to prev pr scope * remove the params too * make style * edits, move max length incorporating prompt to whisper * address comments * remove asr pipeline prompt decoding, fix indexing * address comments (more tests, validate prompt) * un-comment out tests (from debug) * remove old comment * address comments * fix typo * remove timestamp token from test * make style * cleanup * copy method to fast tokenizer, set max_new_tokens for test * prompt_ids type just pt * address Amy's comments * make style
This commit is contained in:
@@ -194,6 +194,25 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
merge = _find_longest_common_sequence([seq1, seq2, seq3])
|
||||
self.assertEqual(merge, [1, 2, 3, 4, 5, 6, 7, 8])
|
||||
|
||||
def test_skip_special_tokens_skips_prompt_ids(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
rust_tokenizer = self.get_rust_tokenizer()
|
||||
# fmt: off
|
||||
encoded_input = [
|
||||
50361, 2221, 13, 2326, 388, 391, 50258, 50259, 50359,
|
||||
50363, 1282, 264, 2674, 9156, 295, 1523, 11, 2221, 13,
|
||||
2326, 388, 391, 13657, 365, 2681, 21296, 17711, 13, 50257,
|
||||
]
|
||||
# fmt: on
|
||||
expected_with_special_tokens = "<|startofprev|> Mr. Quilter<|startoftranscript|><|en|><|transcribe|><|notimestamps|> On the general principles of art, Mr. Quilter writes with equal lucidity.<|endoftext|>"
|
||||
expected_without_special_tokens = " On the general principles of art, Mr. Quilter writes with equal lucidity."
|
||||
self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens)
|
||||
self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens)
|
||||
self.assertEqual(rust_tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens)
|
||||
self.assertEqual(
|
||||
rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens
|
||||
)
|
||||
|
||||
|
||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||
checkpoint_name = "openai/whisper-small.en"
|
||||
|
||||
Reference in New Issue
Block a user