From 5c6257d1fcef7b43f8654df808fc8b2be4ec402e Mon Sep 17 00:00:00 2001 From: benniekiss <63211101+benniekiss@users.noreply.github.com> Date: Thu, 12 Sep 2024 12:48:36 -0400 Subject: [PATCH] [whisper] Clarify error message when setting max_new_tokens (#33324) * clarify error message when setting max_new_tokens * sync error message in test_generate_with_prompt_ids_max_length * there is no self --- src/transformers/models/whisper/generation_whisper.py | 4 ++-- tests/models/whisper/test_modeling_whisper.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index c67aa0cd01..91812155c5 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -1698,8 +1698,8 @@ class WhisperGenerationMixin: max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0 if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions: raise ValueError( - f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` " - f"is {max_new_tokens}. Thus, the combined length of " + f"The length of `decoder_input_ids`, including special start tokens, prompt tokens, and previous tokens, is {decoder_input_ids.shape[-1]}, " + f" and `max_new_tokens` is {max_new_tokens}. Thus, the combined length of " f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the " f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. " "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, " diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index e503937458..09be23a0d3 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1349,8 +1349,8 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi with self.assertRaisesRegex( ValueError, - f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` " - f"is {max_new_tokens}. Thus, the combined length of " + f"The length of `decoder_input_ids`, including special start tokens, prompt tokens, and previous tokens, is {decoder_input_ids.shape[-1]}, " + f" and `max_new_tokens` is {max_new_tokens}. Thus, the combined length of " f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the " f"`max_target_positions` of the Whisper model: {config.max_target_positions}. " "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "