[Whisper] Block language/task args for English-only (#27322)
* [Whisper] Block language/task args for English-only * Update src/transformers/models/whisper/modeling_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -1841,6 +1841,22 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
generation_config.return_timestamps = False
|
generation_config.return_timestamps = False
|
||||||
|
|
||||||
|
if is_multilingual is not None:
|
||||||
|
if not hasattr(generation_config, "is_multilingual"):
|
||||||
|
raise ValueError(
|
||||||
|
"The generation config is outdated and is thus not compatible with the `is_multilingual` argument "
|
||||||
|
"to `generate`. Please update the generation config as per the instructions "
|
||||||
|
"https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
|
||||||
|
)
|
||||||
|
generation_config.is_multilingual = is_multilingual
|
||||||
|
|
||||||
|
if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual:
|
||||||
|
if task is not None or language is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot specify `task` or `language` for an English-only model. If the model is intended to be "
|
||||||
|
"multilingual, pass `is_multilingual=True` to generate, or update the generation config."
|
||||||
|
)
|
||||||
|
|
||||||
if language is not None:
|
if language is not None:
|
||||||
if not hasattr(generation_config, "lang_to_id"):
|
if not hasattr(generation_config, "lang_to_id"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -852,6 +852,44 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
output_3 = speech_translator(filename)
|
output_3 = speech_translator(filename)
|
||||||
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})
|
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})
|
||||||
|
|
||||||
|
@slow
@require_torch
def test_whisper_language(self):
    """Check how the ASR pipeline handles the `language` generate argument for Whisper.

    Three cases are exercised against the same audio sample:

    1. The English-only checkpoint (`openai/whisper-tiny.en`) transcribes when no
       `language` is passed.
    2. The English-only checkpoint raises a `ValueError` when `language` is passed
       through `generate_kwargs`.
    3. The multilingual checkpoint (`openai/whisper-tiny`) accepts `language`.
    """
    speech_recognizer = pipeline(
        task="automatic-speech-recognition",
        model="openai/whisper-tiny.en",
        framework="pt",
    )
    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    filename = ds[0]["file"]

    # 1. English-only model compatible with no language argument
    output = speech_recognizer(filename)
    self.assertEqual(
        output,
        {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
    )

    # 2. English-only Whisper does not accept the language argument
    # NOTE: `assertRaisesRegex` searches this pattern in the raised message, so it
    # must spell `language` exactly as the error message raised by `generate` does.
    with self.assertRaisesRegex(
        ValueError,
        "Cannot specify `task` or `language` for an English-only model. If the model is intended to be multilingual, "
        "pass `is_multilingual=True` to generate, or update the generation config.",
    ):
        _ = speech_recognizer(filename, generate_kwargs={"language": "en"})

    # 3. Multilingual model accepts language argument
    speech_recognizer = pipeline(
        task="automatic-speech-recognition",
        model="openai/whisper-tiny",
        framework="pt",
    )
    output = speech_recognizer(filename, generate_kwargs={"language": "en"})
    self.assertEqual(
        output,
        {"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."},
    )
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
|
|||||||
Reference in New Issue
Block a user