From 38b53da38af231b0af967d15ca29c52470e402d5 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Fri, 26 Apr 2024 17:06:03 +0100 Subject: [PATCH] [examples] update whisper fine-tuning (#29938) * [examples] update whisper fine-tuning * deprecate forced/suppress tokens * item assignment * update readme * final fix --- examples/pytorch/speech-recognition/README.md | 9 ++-- .../run_speech_recognition_seq2seq.py | 47 ++++++++++++++----- 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/examples/pytorch/speech-recognition/README.md b/examples/pytorch/speech-recognition/README.md index b9cab9513b..4990219f42 100644 --- a/examples/pytorch/speech-recognition/README.md +++ b/examples/pytorch/speech-recognition/README.md @@ -368,6 +368,7 @@ python run_speech_recognition_seq2seq.py \ --dataset_name="mozilla-foundation/common_voice_11_0" \ --dataset_config_name="hi" \ --language="hindi" \ + --task="transcribe" \ --train_split_name="train+validation" \ --eval_split_name="test" \ --max_steps="5000" \ @@ -384,12 +385,10 @@ python run_speech_recognition_seq2seq.py \ --save_steps="1000" \ --generation_max_length="225" \ --preprocessing_num_workers="16" \ - --length_column_name="input_length" \ --max_duration_in_seconds="30" \ --text_column_name="sentence" \ --freeze_feature_encoder="False" \ --gradient_checkpointing \ - --group_by_length \ --fp16 \ --overwrite_output_dir \ --do_train \ @@ -399,7 +398,8 @@ python run_speech_recognition_seq2seq.py \ ``` On a single V100, training should take approximately 8 hours, with a final cross-entropy loss of **1e-4** and word error rate of **32.6%**. -If training on a different language, you should be sure to change the `language` argument. The `language` argument should be omitted for English speech recognition. +If training on a different language, you should be sure to change the `language` argument. 
The `language` and `task` +arguments should be omitted for English speech recognition. #### Multi GPU Whisper Training The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint on the Hindi subset of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) using 2 GPU devices in half-precision: @@ -410,6 +410,7 @@ torchrun \ --dataset_name="mozilla-foundation/common_voice_11_0" \ --dataset_config_name="hi" \ --language="hindi" \ + --task="transcribe" \ --train_split_name="train+validation" \ --eval_split_name="test" \ --max_steps="5000" \ @@ -425,12 +426,10 @@ torchrun \ --save_steps="1000" \ --generation_max_length="225" \ --preprocessing_num_workers="16" \ - --length_column_name="input_length" \ --max_duration_in_seconds="30" \ --text_column_name="sentence" \ --freeze_feature_encoder="False" \ --gradient_checkpointing \ - --group_by_length \ --fp16 \ --overwrite_output_dir \ --do_train \ diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py index 3a596e2cb7..f352954d80 100755 --- a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py +++ b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py @@ -119,16 +119,15 @@ class ModelArguments: ) forced_decoder_ids: List[List[int]] = field( default=None, - metadata={ - "help": ( - "A list of pairs of integers which indicates a mapping from generation indices to token indices " - "that will be forced before sampling. For example, [[0, 123]] means the first generated token " - "will always be a token of index 123." - ) - }, + metadata={"help": "Deprecated. Please use the `language` and `task` arguments instead."}, ) suppress_tokens: List[int] = field( - default=None, metadata={"help": "A list of tokens that will be suppressed at generation."} + default=None, metadata={ + "help": ( + "Deprecated. 
The use of `suppress_tokens` should not be required for the majority of fine-tuning examples. " + "Should you need to use `suppress_tokens`, please manually update them in the fine-tuning script directly." + ) + }, ) apply_spec_augment: bool = field( default=False, @@ -400,8 +399,6 @@ def main(): trust_remote_code=model_args.trust_remote_code, ) - config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens}) - # SpecAugment for whisper models if getattr(config, "model_type", None) == "whisper": config.update({"apply_spec_augment": model_args.apply_spec_augment}) @@ -440,9 +437,35 @@ def main(): model.freeze_encoder() model.model.encoder.gradient_checkpointing = False - if data_args.language is not None: - # We only need to set the task id when the language is specified (i.e. in a multilingual setting) + if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual: + # We only need to set the language and task ids in a multilingual setting tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task) + model.generation_config.update( + **{ + "language": data_args.language, + "task": data_args.task, + } + ) + elif data_args.language is not None: + raise ValueError( + "Setting language token for an English-only checkpoint is not permitted. The language argument should " + "only be set for multilingual checkpoints." + ) + + # TODO (Sanchit): deprecate these arguments in v4.41 + if model_args.forced_decoder_ids is not None: + logger.warning( + "The use of `forced_decoder_ids` is deprecated and will be removed in v4.41. " + "Please use the `language` and `task` arguments instead" + ) + model.generation_config.forced_decoder_ids = model_args.forced_decoder_ids + + if model_args.suppress_tokens is not None: + logger.warning( + "The use of `suppress_tokens` is deprecated and will be removed in v4.41. "
+ "Should you need `suppress_tokens`, please manually set them in the fine-tuning script." + ) + model.generation_config.suppress_tokens = model_args.suppress_tokens # 6. Resample speech dataset if necessary dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate