fix: Passing language as acronym to Whisper generate (#23141)
* add fix
* address comments
* remove error formatting
This commit is contained in:
@@ -1562,6 +1562,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
generation_config.return_timestamps = False
|
generation_config.return_timestamps = False
|
||||||
|
|
||||||
if language is not None:
|
if language is not None:
|
||||||
|
language = language.lower()
|
||||||
generation_config.language = language
|
generation_config.language = language
|
||||||
if task is not None:
|
if task is not None:
|
||||||
generation_config.task = task
|
generation_config.task = task
|
||||||
@@ -1573,10 +1574,13 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
language_token = generation_config.language
|
language_token = generation_config.language
|
||||||
elif generation_config.language in TO_LANGUAGE_CODE.keys():
|
elif generation_config.language in TO_LANGUAGE_CODE.keys():
|
||||||
language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
|
language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
|
||||||
|
elif generation_config.language in TO_LANGUAGE_CODE.values():
|
||||||
|
language_token = f"<|{generation_config.language}|>"
|
||||||
else:
|
else:
|
||||||
|
is_language_code = len(generation_config.language) == 2
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported language: {self.language}. Language should be one of:"
|
f"Unsupported language: {generation_config.language}. Language should be one of:"
|
||||||
f" {list(TO_LANGUAGE_CODE.keys()) if generation_config.language in TO_LANGUAGE_CODE.keys() else list(TO_LANGUAGE_CODE.values())}."
|
f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
|
||||||
)
|
)
|
||||||
forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
|
forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -414,6 +414,21 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
model.generate(input_features)
|
model.generate(input_features)
|
||||||
model.generate(input_features, num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
model.generate(input_features, num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
||||||
|
|
||||||
|
def test_generate_language(self):
    """Call `generate` with `language` given as a code, a token, and a name.

    The test makes no assertions on the generated output; it passes as long
    as none of the `generate` calls raises.
    """
    model_config, inputs = self.model_tester.prepare_config_and_inputs()
    features = inputs["input_features"]
    whisper = WhisperForConditionalGeneration(model_config).to(torch_device)

    # Hack to keep the test fast and not require downloading a model with a generation_config
    setattr(whisper.generation_config, "lang_to_id", {"<|en|>": 1})
    setattr(whisper.generation_config, "task_to_id", {"transcribe": 2})

    # In order: language code, tokenizer code, language name.
    for language_spelling in ("en", "<|en|>", "English"):
        whisper.generate(features, language=language_spelling)
|
||||||
|
|
||||||
def test_forward_signature(self):
|
def test_forward_signature(self):
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user