[Whisper] 🚨 Fix whisper decoding 🚨 (#34135)
* do not remove decoder_input_ids for the first segment * do not remove eos token in generate_with_fallback * when removing padding tokens, do not remove eos token * remove eos token in generate (and not in generate_with_fallback!) * reconcile short-form / long-form behavior * correct avg_logprobs calculation * handle eos token in segments * handle decoder_input_ids and eos token in _prepare_decoder_input_ids * fix incorrect time precision * always remove eos token * always remove decoder_input_ids * no need to handle decoder_input_ids and eos token * no need to remove decoder_input_ids * no need to handle eos token * fix num_beams in _retrieve_logit_processors * remove todo inconsistency * no need to add eos token * last_timestamp_pos should indeed be timestamp token pos * patch generate to enable compatibility with GenerationTesterMixin tests * adapt test_generate_continue_from_past_key_values * adapt test_prompt_lookup_decoding_matches_greedy_search * adapt generic GenerationMixin tests to whisper's generate * fix speculative decoding * fix * [run-slow] whisper * change HF_HUB_TOKEN for require_read_token * [run-slow] whisper * prioritize kwargs over generation_config * remove unnecessary args * [run-slow] whisper * update tests * [run-slow] whisper * add comment * update test * [run-slow] whisper * update test + revert require_read_token * docstring updates * revert tokenizer decode args change * do not use a patch + docstring updates * [run-slow] whisper * make * [run-slow] whisper * add a flag to force unique call to generate * test update * [run-slow] whisper * add force_unique_generate_call arg * do not use a patch * correct the timestamps for the pad tokens * docstring update * docstring update * docstring update * update TF tests * add require_read_token * [run-slow] whisper * test reset dynamo * [run-slow] whisper * fix * [run-slow] whisper * avoid iterating twice on current_segments * [run-slow] whisper * [run-slow] whisper --------- Co-authored-by: 
Eustache Le Bihan <eustlb@users.noreply.huggingface.co> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -17,14 +17,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import WhisperConfig, WhisperFeatureExtractor, WhisperProcessor
|
||||
from transformers.testing_utils import is_tf_available, require_tf, require_tokenizers, run_test_in_subprocess, slow
|
||||
from transformers import GenerationConfig, WhisperConfig, WhisperFeatureExtractor, WhisperProcessor
|
||||
from transformers.testing_utils import (
|
||||
is_tf_available,
|
||||
require_read_token,
|
||||
require_tf,
|
||||
require_tokenizers,
|
||||
run_test_in_subprocess,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils import cached_property
|
||||
from transformers.utils.import_utils import is_datasets_available
|
||||
|
||||
@@ -749,7 +757,9 @@ def _test_large_generation(in_queue, out_queue, timeout):
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
|
||||
generated_ids = model.generate(
|
||||
input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
|
||||
input_features,
|
||||
do_sample=False,
|
||||
max_length=20,
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
@@ -772,13 +782,29 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
|
||||
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
|
||||
|
||||
ds = load_dataset("legacy-datasets/common_voice", "ja", split="test", streaming=True, trust_remote_code=True)
|
||||
# update generation config
|
||||
generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")
|
||||
|
||||
token = os.getenv("HF_HUB_READ_TOKEN", True)
|
||||
ds = load_dataset(
|
||||
"mozilla-foundation/common_voice_6_1",
|
||||
"ja",
|
||||
split="test",
|
||||
streaming=True,
|
||||
trust_remote_code=True,
|
||||
token=token,
|
||||
)
|
||||
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
|
||||
input_speech = next(iter(ds))["audio"]["array"]
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
|
||||
generated_ids = model.generate(
|
||||
input_features, do_sample=False, max_length=20, language="<|ja|>", task="transcribe"
|
||||
input_features,
|
||||
do_sample=False,
|
||||
max_length=20,
|
||||
language="<|ja|>",
|
||||
task="transcribe",
|
||||
generation_config=generation_config,
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
@@ -786,7 +812,12 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
|
||||
unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
generated_ids = model.generate(
|
||||
input_features, do_sample=False, max_length=20, language="<|en|>", task="transcribe"
|
||||
input_features,
|
||||
do_sample=False,
|
||||
max_length=20,
|
||||
language="<|en|>",
|
||||
task="transcribe",
|
||||
generation_config=generation_config,
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
@@ -794,7 +825,12 @@ def _test_large_generation_multilingual(in_queue, out_queue, timeout):
|
||||
unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
generated_ids = model.generate(
|
||||
input_features, do_sample=False, max_length=20, language="<|ja|>", task="translate"
|
||||
input_features,
|
||||
do_sample=False,
|
||||
max_length=20,
|
||||
language="<|ja|>",
|
||||
task="translate",
|
||||
generation_config=generation_config,
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
@@ -825,10 +861,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_IDS = [
|
||||
[50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
|
||||
[50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
|
||||
[50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
|
||||
[50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
|
||||
[50258, 50259, 50359, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404],
|
||||
[50258, 50259, 50359, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257],
|
||||
[50258, 50259, 50359, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904],
|
||||
[50258, 50259, 50359, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439]
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
@@ -836,10 +872,10 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad to",
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad",
|
||||
" Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
" He tells us that at this festive season of the year, with Christmas and roast beef",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
|
||||
" He tells us that at this festive season of the year, with Christmas and roast",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all"
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
@@ -1009,6 +1045,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_large_generation_multilingual(self):
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_large_generation_multilingual, inputs=None)
|
||||
|
||||
|
||||
@@ -445,6 +445,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
|
||||
self.maxDiff = 3000
|
||||
|
||||
def prepare_config_and_inputs_for_generate(self, batch_size=2):
|
||||
config, inputs_dict = super().prepare_config_and_inputs_for_generate(batch_size=batch_size)
|
||||
inputs_dict["force_unique_generate_call"] = True
|
||||
return config, inputs_dict
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@@ -1891,8 +1896,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
"ja",
|
||||
split="test",
|
||||
streaming=True,
|
||||
token=token,
|
||||
trust_remote_code=True,
|
||||
token=token,
|
||||
)
|
||||
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
|
||||
|
||||
@@ -2144,11 +2149,16 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
},
|
||||
{
|
||||
"text": " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and",
|
||||
"timestamp": (39.80, 45.36),
|
||||
# "timestamp": (39.80, 45.36),
|
||||
# above is the expected output on A100.
|
||||
# on CI T4s, due to slight differences in floating-point operations, the expected output is below
|
||||
"timestamp": (39.80, 45.38),
|
||||
},
|
||||
{
|
||||
"text": " can discover in it but little of rocky Ithaca.",
|
||||
"timestamp": (45.36, 49.0),
|
||||
# "timestamp": (45.36, 49.0),
|
||||
# see above
|
||||
"timestamp": (45.38, 49.0),
|
||||
},
|
||||
{
|
||||
"text": " Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles",
|
||||
@@ -2275,20 +2285,20 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT = torch.tensor([
|
||||
[0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.8400, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9400, 26.9400, 26.9400, 26.9400],
|
||||
[0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400],
|
||||
[0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200, 12.4200],
|
||||
[0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000, 17.3000],
|
||||
[0.0000, 0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9200, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0600, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 15.6800],
|
||||
[0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 29.6400, 29.6600, 29.6600, 29.6600, 29.6600]
|
||||
[0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200, 15.8200]
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))
|
||||
|
||||
@slow
|
||||
def test_large_token_timestamp_generation(self):
|
||||
def test_small_token_timestamp_generation(self):
|
||||
set_seed(0)
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
|
||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
|
||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
|
||||
model.to(torch_device)
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
@@ -2305,10 +2315,10 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT = torch.tensor([
|
||||
[0.0000, 0.0000, 0.6200, 0.7400, 0.8600, 1.0000, 1.0400, 1.3000, 1.4400, 1.7800, 2.1800, 2.2800, 2.5000, 2.9200, 3.0000, 3.3800, 3.5000, 3.6000, 3.8400, 4.1000, 4.4000, 4.6800, 5.1400, 5.3600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200],
|
||||
[0.0000, 0.0000, 0.6000, 0.9200, 1.2200, 1.3400, 1.4200, 1.5400, 1.5800, 1.7400, 2.0600, 2.3800, 3.0400, 3.3800, 3.6400, 4.1200, 4.3600, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800],
|
||||
[0.0000, 0.0000, 0.5400, 0.8200, 1.1600, 1.4600, 1.7400, 1.8800, 2.3400, 2.7400, 3.1400, 3.2200, 3.5400, 4.2800, 4.5600, 4.8200, 5.0600, 5.3200, 5.6600, 5.9600, 6.1400, 6.4000, 6.8400, 7.8800, 8.0200, 8.3600, 8.7000, 9.0200, 9.3200, 9.5000, 9.8400, 10.3000, 10.6600, 11.0800, 11.3600, 11.4600, 11.8000],
|
||||
[0.0000, 0.0000, 0.5600, 0.7600, 1.0600, 1.4000, 1.8800, 2.2600, 2.6200, 2.8000, 2.9600, 3.0000, 3.2000, 3.4400, 3.6800, 4.0000, 4.6000, 5.0000, 5.3200, 5.4800, 6.0600, 6.0600, 6.1000, 6.3200, 6.7400, 7.0000, 7.2200, 7.4000, 7.7600, 8.0600, 8.5600, 8.8600, 8.9400, 9.1000, 9.3400, 9.8800, 9.8800]
|
||||
[0.0000, 0.0000, 0.7400, 0.8000, 0.9800, 1.0200, 1.1400, 1.4000, 1.5200, 1.9200, 2.2600, 2.3800, 2.5400, 2.8600, 3.2600, 3.3400, 3.4400, 3.6000, 3.6800, 3.9200, 4.2000, 4.4800, 4.7800, 5.2600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200],
|
||||
[0.0000, 0.0000, 0.7600, 1.0000, 1.3000, 1.3800, 1.5200, 1.5800, 1.7000, 1.8400, 2.1000, 2.5000, 3.1400, 3.4400, 3.7400, 4.1800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800],
|
||||
[0.0000, 0.0000, 0.6600, 0.9000, 1.2200, 1.5200, 1.7600, 2.0200, 2.4000, 2.9200, 3.1800, 3.3200, 3.6200, 4.1000, 4.3600, 4.7800, 5.1200, 5.3400, 5.7200, 6.0600, 6.2000, 6.2000, 6.2000, 6.5000, 6.9000, 7.6400, 8.0000, 8.2400, 8.5200, 8.7400, 9.0800, 9.4000, 9.5400, 9.9400, 10.4200, 10.7600, 11.1200, 11.4400, 11.5800, 11.8600, 12.4600],
|
||||
[0.0000, 0.0000, 0.6600, 0.8600, 1.1400, 1.5000, 1.9600, 2.3600, 2.6400, 2.9800, 3.1200, 3.2400, 3.4800, 3.7800, 4.1400, 4.6400, 5.0800, 5.4400, 6.2200, 6.2200, 6.2200, 6.4000, 6.8400, 7.1200, 7.2600, 7.4800, 7.8200, 8.1400, 8.7000, 9.0200, 9.0200, 9.2000, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800, 9.8800]
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
@@ -3331,6 +3341,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
# only permit 4 compilations: 2 prefill steps and 2 decoding steps (1 for each of conditioned/not conditioned)
|
||||
torch._dynamo.config.cache_size_limit = 4
|
||||
torch._dynamo.reset()
|
||||
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||
|
||||
Reference in New Issue
Block a user