fix: Passing language as acronym to Whisper generate (#23141)
* add fix * address comments * remove error formatting
This commit is contained in:
@@ -414,6 +414,21 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
model.generate(input_features)
|
||||
model.generate(input_features, num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
||||
|
||||
def test_generate_language(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
input_features = input_dict["input_features"]
|
||||
model = WhisperForConditionalGeneration(config).to(torch_device)
|
||||
# Hack to keep the test fast and not require downloading a model with a generation_config
|
||||
model.generation_config.__setattr__("lang_to_id", {"<|en|>": 1})
|
||||
model.generation_config.__setattr__("task_to_id", {"transcribe": 2})
|
||||
|
||||
# test language code
|
||||
model.generate(input_features, language="en")
|
||||
# test tokenizer code
|
||||
model.generate(input_features, language="<|en|>")
|
||||
# test language name
|
||||
model.generate(input_features, language="English")
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user