[Whisper] Strip prompt before finding common subsequence (#27836)
This commit is contained in:
@@ -897,11 +897,15 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
||||
right_stride_start = None
|
||||
|
||||
all_special_ids = set(tokenizer.all_special_ids)
|
||||
prompt_token_id = tokenizer.convert_tokens_to_ids("<|startofprev|>")
|
||||
decoder_start_token_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
|
||||
# - iterate over all outputs
|
||||
for chunk_id, output in enumerate(model_outputs):
|
||||
# We can drop everything to Python list, it's going to make
|
||||
# our lives easier
|
||||
token_ids = output["tokens"][0].tolist()
|
||||
# (possibly) remove the prompt from the token ids
|
||||
token_ids = tokenizer._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
|
||||
if return_timestamps == "word":
|
||||
token_timestamps = output["token_timestamps"][0].tolist()
|
||||
|
||||
|
||||
@@ -1343,6 +1343,42 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
self.assertEqual(output, [{"text": ANY(str)}])
|
||||
self.assertEqual(output[0]["text"][:6], "<s> <s")
|
||||
|
||||
@require_torch
@slow
def test_whisper_prompted(self):
    """Check that `prompt_ids` influence chunked Whisper transcription.

    The same audio sample is transcribed twice through a chunked
    automatic-speech-recognition pipeline: once without a prompt and once
    with a prompt that spells the speaker's name as "Quillter". Both
    transcriptions are compared against hard-coded reference strings.
    """
    # Local import: torch is guaranteed to be installed by @require_torch.
    import torch

    # The model and the pipeline below are pinned to CUDA. Skip instead of
    # erroring out on machines that have torch but no GPU.
    if not torch.cuda.is_available():
        self.skipTest("test_whisper_prompted requires a CUDA device")

    processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    model = model.to("cuda")

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=16,
        device="cuda:0",
    )

    dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
    sample = dataset[0]["audio"]

    # prompt the model to misspell "Mr Quilter" as "Mr Quillter"
    whisper_prompt = "Mr. Quillter."
    prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt")

    # The first call receives a copy so the second call still gets the
    # original `sample` object untouched.
    unprompted_result = pipe(sample.copy())["text"]
    prompted_result = pipe(sample, generate_kwargs={"prompt_ids": prompt_ids})["text"]

    # fmt: off
    EXPECTED_UNPROMPTED_RESULT = " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of rocky Ithaca. Lennils, pictures are a sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampoo or a Turkish bath. Next man"
    EXPECTED_PROMPTED_RESULT = " Mr. Quillter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quillter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins work is really great after all, and can discover in it but little of rocky Ithaca. Lennils, pictures are a sort of upguards and atom paintings, and Mason's exquisite itals are as national as a jingo poem. Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampoo or a Turkish bath. Next man."
    # fmt: on

    self.assertEqual(unprompted_result, EXPECTED_UNPROMPTED_RESULT)
    self.assertEqual(prompted_result, EXPECTED_PROMPTED_RESULT)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_whisper_longform(self):
|
||||
|
||||
Reference in New Issue
Block a user