[examples] update whisper fine-tuning (#29938)

* [examples] update whisper fine-tuning

* deprecate forced/suppress tokens

* item assignment

* update readme

* final fix
This commit is contained in:
Sanchit Gandhi
2024-04-26 17:06:03 +01:00
committed by GitHub
parent aafa7ce72b
commit 38b53da38a
2 changed files with 39 additions and 17 deletions

View File

@@ -368,6 +368,7 @@ python run_speech_recognition_seq2seq.py \
--dataset_name="mozilla-foundation/common_voice_11_0" \
--dataset_config_name="hi" \
--language="hindi" \
--task="transcribe" \
--train_split_name="train+validation" \
--eval_split_name="test" \
--max_steps="5000" \
@@ -384,12 +385,10 @@ python run_speech_recognition_seq2seq.py \
--save_steps="1000" \
--generation_max_length="225" \
--preprocessing_num_workers="16" \
--length_column_name="input_length" \
--max_duration_in_seconds="30" \
--text_column_name="sentence" \
--freeze_feature_encoder="False" \
--gradient_checkpointing \
--group_by_length \
--fp16 \
--overwrite_output_dir \
--do_train \
@@ -399,7 +398,8 @@ python run_speech_recognition_seq2seq.py \
```
On a single V100, training should take approximately 8 hours, with a final cross-entropy loss of **1e-4** and word error rate of **32.6%**.
If training on a different language, you should be sure to change the `language` argument. The `language` argument should be omitted for English speech recognition.
If training on a different language, you should be sure to change the `language` argument. The `language` and `task`
arguments should be omitted for English speech recognition.
#### Multi GPU Whisper Training
The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint on the Hindi subset of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) using 2 GPU devices in half-precision:
@@ -410,6 +410,7 @@ torchrun \
--dataset_name="mozilla-foundation/common_voice_11_0" \
--dataset_config_name="hi" \
--language="hindi" \
--task="transcribe" \
--train_split_name="train+validation" \
--eval_split_name="test" \
--max_steps="5000" \
@@ -425,12 +426,10 @@ torchrun \
--save_steps="1000" \
--generation_max_length="225" \
--preprocessing_num_workers="16" \
--length_column_name="input_length" \
--max_duration_in_seconds="30" \
--text_column_name="sentence" \
--freeze_feature_encoder="False" \
--gradient_checkpointing \
--group_by_length \
--fp16 \
--overwrite_output_dir \
--do_train \