Support generating with fallback for short form audio in Whisper (#30984)
* remove is_shortform * adapt _retrieve_max_frames_and_seek for short_form * return bos token in short and long form * add decoder_input_ids to short form audios * add eos token for short form * handle short form token_timestamps * no need to return scores * add is_shortform conditions * handle when max_new_tokens is None - short form * handle assistant decoding * fix * handle return_dict_in_generate * handle split_by_batch for encoder_attentions attribute * handle num_beams>1 * handle num_return_sequences>1 in generate_with_fallback * handle num_return_sequences>1 with return_dict_in_generate=True * raise error if max_new_tokens + decoder_inputs_ids > max_target_pos * fix * apply review suggestions * fix * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * fix * logits for both short form and long form * handle if logits_processor is None * test * apply review changes to num_return_sequences * add _expand_variables_for_generation * remove short form commented section * update comments * uncomment num_beams line in generate_with_fallback * update assistant decoding * handle return_segment with short form generation * up * fix output format is_shortform * overwrite beam_sample test * update _set_return_timestamps * apply review suggestions * apply review suggestions * remove seek_outputs_short_form * fix _stack_split_outputs * fix stack dim in _stack_split_outputs * update tests * fix past_key_values + beam tests * fix * clean _expand_variables_for_generation * make style * fix slow tests * make style * max_length condition * make style * add slow tests for shortform fallback * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * apply review changes * Update src/transformers/models/whisper/generation_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * up * fix slow tests * apply review suggestions * update test * make style * small fix * fix * fix test_new_cache_format * fix past_key_values * fix * make style * fix slow tests * fix --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -65,6 +65,15 @@ if is_torch_available():
|
||||
WhisperProcessor,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.generation import (
|
||||
BeamSampleDecoderOnlyOutput,
|
||||
BeamSampleEncoderDecoderOutput,
|
||||
BeamSearchDecoderOnlyOutput,
|
||||
BeamSearchEncoderDecoderOutput,
|
||||
GenerateBeamDecoderOnlyOutput,
|
||||
GenerateBeamEncoderDecoderOutput,
|
||||
PhrasalConstraint,
|
||||
)
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
|
||||
|
||||
@@ -1539,6 +1548,241 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_longform_generate_multi_batch_cond_prev(self):
|
||||
self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)
|
||||
|
||||
def test_beam_sample_generate_dict_output(self):
|
||||
# We overwrite test_beam_sample_generate_dict_output in test_utils as
|
||||
# we can only perform beam search if the temperature is set to 0 in Whisper.
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# disable cache
|
||||
config.use_cache = False
|
||||
|
||||
model = WhisperForConditionalGeneration(config).to(torch_device).eval()
|
||||
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
|
||||
# With Whisper, we can only perform a beam search if the temperature is set to 0.
|
||||
logits_warper_kwargs["temperature"] = 0
|
||||
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
|
||||
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
|
||||
|
||||
output_generate = self._beam_sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"])
|
||||
|
||||
def test_beam_search_generate_dict_output(self):
|
||||
# We overwrite test_beam_search_generate_dict_output in test_utils as
|
||||
# we can only perform beam search if the temperature is set to 0 in Whisper.
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# disable cache
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
)
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
|
||||
# With Whisper, we can only perform a beam search if the temperature is set to 0.
|
||||
logits_process_kwargs["temperature"] = 0
|
||||
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
|
||||
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
|
||||
|
||||
output_generate = self._beam_search_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(
|
||||
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
||||
)
|
||||
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
# We overwrite test_beam_search_generate_dict_outputs_use_cache in test_utils as
|
||||
# we can only perform beam search if the temperature is set to 0 in Whisper.
|
||||
for model_class in self.all_generative_model_classes:
|
||||
# enable cache
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
if not hasattr(config, "use_cache"):
|
||||
self.skipTest("This model doesn't support caching")
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
)
|
||||
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
|
||||
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
|
||||
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
output_generate = self._beam_search_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
self._check_outputs(
|
||||
output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"]
|
||||
)
|
||||
|
||||
def test_group_beam_search_generate_dict_output(self):
|
||||
# We overwrite test_group_beam_search_generate_dict_output in test_utils as
|
||||
# we can only perform beam search if the temperature is set to 0 in Whisper.
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
)
|
||||
|
||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
||||
|
||||
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
|
||||
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
|
||||
|
||||
output_generate = self._group_beam_search_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(
|
||||
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
||||
)
|
||||
|
||||
def test_constrained_beam_search_generate_dict_output(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# disable cache
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
)
|
||||
|
||||
# Sample constraints
|
||||
min_id = 3
|
||||
max_id = model.config.vocab_size
|
||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
||||
constraints = [
|
||||
PhrasalConstraint(force_tokens),
|
||||
]
|
||||
|
||||
beam_kwargs = self._get_constrained_beam_kwargs()
|
||||
output_generate = self._constrained_beam_search_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||
else:
|
||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||
# Retrocompatibility check
|
||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||
|
||||
self._check_outputs(
|
||||
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_return_sequences"]
|
||||
)
|
||||
|
||||
def test_custom_4d_attention_mask(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32)
|
||||
@@ -2680,6 +2924,55 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert decoded == EXPECTED_TEXT
|
||||
|
||||
@slow
|
||||
def test_whisper_shortform_single_batch_prev_cond(self):
|
||||
# fmt: off
|
||||
EXPECTED_TEXT = [" Folks, I spend a lot of time right over there, night after night, actually. Carefully selecting for you the day's newsiest, most aerodynamic headlines, stress testing and the most topical antilock breaks and power steering pain, Stakingly stitching, leather seating so soft, it would make JD power and her associate blush. If you were to create the luxury sedan that is my nightly model, but sometimes— you're sometimes, folks— I lurched the consciousness and the back of an abandoned school bus"]
|
||||
EXPECTED_TEXT1 = [" Folks, I spend a lot of time right over there night after night after, actually. Carefully selecting for you the day's noisiest, most aerodynamic headlines, stress testing, and the most topical, anti-lock breaks and power steering, painstakingly stitching, leather seating, so soft, it would make JD power and her associates blush to create the luxury sedan that is my nightly monologue. But sometimes, you sometimes, folks. I lurched a consciousness in the back of an abandoned school"]
|
||||
# fmt: on
|
||||
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||
model = model.to(torch_device)
|
||||
|
||||
ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
|
||||
dataset = ds.cast_column("audio", Audio(sampling_rate=16000))
|
||||
|
||||
one_audio = dataset[1]["audio"]["array"]
|
||||
|
||||
input_features = processor(one_audio, return_tensors="pt", sampling_rate=16_000)["input_features"]
|
||||
input_features = input_features.to(device=torch_device)
|
||||
|
||||
gen_kwargs = {
|
||||
"return_timestamps": True,
|
||||
"no_speech_threshold": 0.6,
|
||||
"temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
"compression_ratio_threshold": 1.35,
|
||||
"condition_on_prev_tokens": True,
|
||||
"logprob_threshold": -1.0,
|
||||
}
|
||||
|
||||
torch.manual_seed(0)
|
||||
result = model.generate(input_features, **gen_kwargs)
|
||||
decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
|
||||
|
||||
assert decoded == EXPECTED_TEXT
|
||||
|
||||
gen_kwargs = {
|
||||
"return_timestamps": True,
|
||||
"no_speech_threshold": 0.3,
|
||||
"temperature": (0.0, 0.2),
|
||||
"compression_ratio_threshold": 1,
|
||||
"condition_on_prev_tokens": False,
|
||||
"logprob_threshold": -1.0,
|
||||
}
|
||||
|
||||
torch.manual_seed(0)
|
||||
result = model.generate(input_features, **gen_kwargs)
|
||||
decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
|
||||
|
||||
assert decoded == EXPECTED_TEXT1
|
||||
|
||||
@slow
|
||||
def test_whisper_longform_single_batch_beam(self):
|
||||
# fmt: off
|
||||
@@ -2931,6 +3224,57 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
elif isinstance(EXPECTED_TEXT[i], tuple):
|
||||
assert decoded_all[i] in EXPECTED_TEXT[i]
|
||||
|
||||
@slow
|
||||
def test_whisper_shortform_multi_batch_hard_prev_cond(self):
|
||||
# Without this set here, this test may fail if it is run with other tests (say, `test_tiny_*`). It's unclear
|
||||
# why other tests may affect this tests: it seems some random operations are beyond the scene.
|
||||
set_seed(0)
|
||||
# fmt: off
|
||||
EXPECTED_TEXT = [
|
||||
' Mr. Kfilter is the apostle of the Middle Classes and we are glad to welcome his gospel.',
|
||||
" Nor is Mr. Qilter's manner less interesting than his matter.",
|
||||
' He tells us that at this festive season of the year, with Christmas and roce beef, looming before us, similarly drawn from eating and its results occur most readily to the mind.',
|
||||
' He has grabbed those with her surfered trigger late and his work is really a great after all, and can discover it in it but little of Rocky Ithaka.',
|
||||
" L'Neile's pictures are a sort of upguards and add-um paintings, and Maessin's exquisite Itals are a national as a jingo poem. Mr. Birkett Foster's landscapes smiled at one much in the same way that Mr. Carcher used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slapper in the back, before he says,",
|
||||
' It is obviously unnecessary for us, to point out how luminous these criticisms are, how delicate and expression.',
|
||||
' On the general principles of art and Mr. Kriltor rights with equal lucidity.',
|
||||
' Painting, he tells us is of a different quality to mathematics and finish in art is adding more effect.',
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
|
||||
model = model.to(torch_device)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
num_samples = 8
|
||||
|
||||
audio = ds[:num_samples]["audio"]
|
||||
audios = [x["array"] for x in audio]
|
||||
|
||||
inputs = processor(
|
||||
audios,
|
||||
return_tensors="pt",
|
||||
sampling_rate=16_000,
|
||||
)
|
||||
inputs = inputs.to(device=torch_device)
|
||||
|
||||
gen_kwargs = {
|
||||
"return_timestamps": True,
|
||||
"no_speech_threshold": 0.6,
|
||||
"temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
"compression_ratio_threshold": 1.35,
|
||||
"condition_on_prev_tokens": True,
|
||||
"logprob_threshold": -1.0,
|
||||
}
|
||||
|
||||
result = model.generate(**inputs, **gen_kwargs)
|
||||
decoded_all = processor.batch_decode(result.sequences, skip_special_tokens=True)
|
||||
|
||||
for i in range(num_samples):
|
||||
if isinstance(EXPECTED_TEXT[i], str):
|
||||
assert decoded_all[i] == EXPECTED_TEXT[i]
|
||||
|
||||
@slow
|
||||
def test_whisper_longform_no_speech_detection(self):
|
||||
# fmt: off
|
||||
|
||||
Reference in New Issue
Block a user