Fix kwargs handling in generate_with_fallback (#29225)
* Fix `generate_with_fallback` `**kwargs`
* Change pop to get
* Delete keys from kwargs to prevent overriding generation_config
* Revert to passing kwargs by reference, but make a (shallow) copy
* dict -> copy.copy
* Add test_whisper_longform_multi_batch_beam
This commit is contained in:
@@ -755,6 +755,8 @@ class WhisperGenerationMixin:
|
|||||||
do_condition_on_prev_tokens,
|
do_condition_on_prev_tokens,
|
||||||
kwargs,
|
kwargs,
|
||||||
):
|
):
|
||||||
|
kwargs = copy.copy(kwargs)
|
||||||
|
|
||||||
# 6.6 Batch generate current chunk
|
# 6.6 Batch generate current chunk
|
||||||
seek_sequence_list = [None for _ in range(cur_bsz)]
|
seek_sequence_list = [None for _ in range(cur_bsz)]
|
||||||
seek_outputs_list = [None for _ in range(cur_bsz)]
|
seek_outputs_list = [None for _ in range(cur_bsz)]
|
||||||
@@ -769,8 +771,12 @@ class WhisperGenerationMixin:
|
|||||||
generation_config.do_sample = temperature is not None and temperature > 0.0
|
generation_config.do_sample = temperature is not None and temperature > 0.0
|
||||||
|
|
||||||
generation_config.temperature = temperature if generation_config.do_sample else 1.0
|
generation_config.temperature = temperature if generation_config.do_sample else 1.0
|
||||||
generation_config.num_beams = kwargs.pop("num_beams", 1) if not generation_config.do_sample else 1
|
generation_config.num_beams = kwargs.get("num_beams", 1) if not generation_config.do_sample else 1
|
||||||
|
|
||||||
|
generate_kwargs = copy.copy(kwargs)
|
||||||
|
for key in ["do_sample", "temperature", "num_beams"]:
|
||||||
|
if key in generate_kwargs:
|
||||||
|
del generate_kwargs[key]
|
||||||
seek_outputs = super().generate(
|
seek_outputs = super().generate(
|
||||||
segment_input,
|
segment_input,
|
||||||
generation_config,
|
generation_config,
|
||||||
@@ -779,7 +785,7 @@ class WhisperGenerationMixin:
|
|||||||
prefix_allowed_tokens_fn,
|
prefix_allowed_tokens_fn,
|
||||||
synced_gpus,
|
synced_gpus,
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
**kwargs,
|
**generate_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# post-process sequence tokens and outputs to be in list form
|
# post-process sequence tokens and outputs to be in list form
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user