[Whisper] Strip prompt before finding common subsequence (#27836)

This commit is contained in:
Sanchit Gandhi
2024-05-22 17:25:47 +01:00
committed by GitHub
parent b1065aa08a
commit 0948c827de
2 changed files with 40 additions and 0 deletions

View File

@@ -897,11 +897,15 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
right_stride_start = None
all_special_ids = set(tokenizer.all_special_ids)
prompt_token_id = tokenizer.convert_tokens_to_ids("<|startofprev|>")
decoder_start_token_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
# - iterate over all outputs
for chunk_id, output in enumerate(model_outputs):
# We can drop everything to Python list, it's going to make
# our lives easier
token_ids = output["tokens"][0].tolist()
# (possibly) remove the prompt from the token ids
token_ids = tokenizer._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
if return_timestamps == "word":
token_timestamps = output["token_timestamps"][0].tolist()

View File

@@ -1343,6 +1343,42 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "<s> <s")
@require_torch
@slow
def test_whisper_prompted(self):
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model = model.to("cuda")
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=30,
batch_size=16,
device="cuda:0",
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
# prompt the model to misspell "Mr Quilter" as "Mr Quillter"
whisper_prompt = "Mr. Quillter."
prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt")
unprompted_result = pipe(sample.copy())["text"]
prompted_result = pipe(sample, generate_kwargs={"prompt_ids": prompt_ids})["text"]
# fmt: off
EXPECTED_UNPROMPTED_RESULT = " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of rocky Ithaca. Lennils, pictures are a sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampoo or a Turkish bath. Next man"
EXPECTED_PROMPTED_RESULT = " Mr. Quillter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quillter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins work is really great after all, and can discover in it but little of rocky Ithaca. Lennils, pictures are a sort of upguards and atom paintings, and Mason's exquisite itals are as national as a jingo poem. Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampoo or a Turkish bath. Next man."
# fmt: on
self.assertEqual(unprompted_result, EXPECTED_UNPROMPTED_RESULT)
self.assertEqual(prompted_result, EXPECTED_PROMPTED_RESULT)
@require_torch
@slow
def test_whisper_longform(self):