[whisper] fix multilingual fine-tuning (#30865)
* [whisper] fix multilingual fine-tuning
* config ids as well
This commit is contained in:
@@ -425,12 +425,8 @@ def main():
|
|||||||
if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
|
if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
|
||||||
# We only need to set the language and task ids in a multilingual setting
|
# We only need to set the language and task ids in a multilingual setting
|
||||||
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
|
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
|
||||||
model.generation_config.update(
|
model.generation_config.language = data_args.language
|
||||||
**{
|
model.generation_config.task = data_args.task
|
||||||
"language": data_args.language,
|
|
||||||
"task": data_args.task,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
elif data_args.language is not None:
|
elif data_args.language is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Setting language token for an English-only checkpoint is not permitted. The language argument should "
|
"Setting language token for an English-only checkpoint is not permitted. The language argument should "
|
||||||
@@ -444,6 +440,9 @@ def main():
|
|||||||
"Please use the `language` and `task` arguments instead"
|
"Please use the `language` and `task` arguments instead"
|
||||||
)
|
)
|
||||||
model.generation_config.forced_decoder_ids = model_args.forced_decoder_ids
|
model.generation_config.forced_decoder_ids = model_args.forced_decoder_ids
|
||||||
|
else:
|
||||||
|
model.generation_config.forced_decoder_ids = None
|
||||||
|
model.config.forced_decoder_ids = None
|
||||||
|
|
||||||
if model_args.suppress_tokens is not None:
|
if model_args.suppress_tokens is not None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
Reference in New Issue
Block a user