[examples] update whisper fine-tuning (#29938)
* [examples] update whisper fine-tuning * deprecate forced/suppress tokens * item assignment * update readme * final fix
This commit is contained in:
@@ -368,6 +368,7 @@ python run_speech_recognition_seq2seq.py \
|
|||||||
--dataset_name="mozilla-foundation/common_voice_11_0" \
|
--dataset_name="mozilla-foundation/common_voice_11_0" \
|
||||||
--dataset_config_name="hi" \
|
--dataset_config_name="hi" \
|
||||||
--language="hindi" \
|
--language="hindi" \
|
||||||
|
--task="transcribe" \
|
||||||
--train_split_name="train+validation" \
|
--train_split_name="train+validation" \
|
||||||
--eval_split_name="test" \
|
--eval_split_name="test" \
|
||||||
--max_steps="5000" \
|
--max_steps="5000" \
|
||||||
@@ -384,12 +385,10 @@ python run_speech_recognition_seq2seq.py \
|
|||||||
--save_steps="1000" \
|
--save_steps="1000" \
|
||||||
--generation_max_length="225" \
|
--generation_max_length="225" \
|
||||||
--preprocessing_num_workers="16" \
|
--preprocessing_num_workers="16" \
|
||||||
--length_column_name="input_length" \
|
|
||||||
--max_duration_in_seconds="30" \
|
--max_duration_in_seconds="30" \
|
||||||
--text_column_name="sentence" \
|
--text_column_name="sentence" \
|
||||||
--freeze_feature_encoder="False" \
|
--freeze_feature_encoder="False" \
|
||||||
--gradient_checkpointing \
|
--gradient_checkpointing \
|
||||||
--group_by_length \
|
|
||||||
--fp16 \
|
--fp16 \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--do_train \
|
--do_train \
|
||||||
@@ -399,7 +398,8 @@ python run_speech_recognition_seq2seq.py \
|
|||||||
```
|
```
|
||||||
On a single V100, training should take approximately 8 hours, with a final cross-entropy loss of **1e-4** and word error rate of **32.6%**.
|
On a single V100, training should take approximately 8 hours, with a final cross-entropy loss of **1e-4** and word error rate of **32.6%**.
|
||||||
|
|
||||||
If training on a different language, you should be sure to change the `language` argument. The `language` argument should be omitted for English speech recognition.
|
If training on a different language, you should be sure to change the `language` argument. The `language` and `task`
|
||||||
|
arguments should be omitted for English speech recognition.
|
||||||
|
|
||||||
#### Multi GPU Whisper Training
|
#### Multi GPU Whisper Training
|
||||||
The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint on the Hindi subset of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) using 2 GPU devices in half-precision:
|
The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint on the Hindi subset of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) using 2 GPU devices in half-precision:
|
||||||
@@ -410,6 +410,7 @@ torchrun \
|
|||||||
--dataset_name="mozilla-foundation/common_voice_11_0" \
|
--dataset_name="mozilla-foundation/common_voice_11_0" \
|
||||||
--dataset_config_name="hi" \
|
--dataset_config_name="hi" \
|
||||||
--language="hindi" \
|
--language="hindi" \
|
||||||
|
--task="transcribe" \
|
||||||
--train_split_name="train+validation" \
|
--train_split_name="train+validation" \
|
||||||
--eval_split_name="test" \
|
--eval_split_name="test" \
|
||||||
--max_steps="5000" \
|
--max_steps="5000" \
|
||||||
@@ -425,12 +426,10 @@ torchrun \
|
|||||||
--save_steps="1000" \
|
--save_steps="1000" \
|
||||||
--generation_max_length="225" \
|
--generation_max_length="225" \
|
||||||
--preprocessing_num_workers="16" \
|
--preprocessing_num_workers="16" \
|
||||||
--length_column_name="input_length" \
|
|
||||||
--max_duration_in_seconds="30" \
|
--max_duration_in_seconds="30" \
|
||||||
--text_column_name="sentence" \
|
--text_column_name="sentence" \
|
||||||
--freeze_feature_encoder="False" \
|
--freeze_feature_encoder="False" \
|
||||||
--gradient_checkpointing \
|
--gradient_checkpointing \
|
||||||
--group_by_length \
|
|
||||||
--fp16 \
|
--fp16 \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--do_train \
|
--do_train \
|
||||||
|
|||||||
@@ -119,16 +119,15 @@ class ModelArguments:
|
|||||||
)
|
)
|
||||||
forced_decoder_ids: List[List[int]] = field(
|
forced_decoder_ids: List[List[int]] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={"help": "Deprecated. Please use the `language` and `task` arguments instead."},
|
||||||
"help": (
|
|
||||||
"A list of pairs of integers which indicates a mapping from generation indices to token indices "
|
|
||||||
"that will be forced before sampling. For example, [[0, 123]] means the first generated token "
|
|
||||||
"will always be a token of index 123."
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
suppress_tokens: List[int] = field(
|
suppress_tokens: List[int] = field(
|
||||||
default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
|
default=None, metadata={
|
||||||
|
"help": (
|
||||||
|
"Deprecated. The use of `suppress_tokens` should not be required for the majority of fine-tuning examples. "
|
||||||
|
"Should you need to use `suppress_tokens`, please manually update them in the fine-tuning script directly."
|
||||||
|
)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
apply_spec_augment: bool = field(
|
apply_spec_augment: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
@@ -400,8 +399,6 @@ def main():
|
|||||||
trust_remote_code=model_args.trust_remote_code,
|
trust_remote_code=model_args.trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
|
|
||||||
|
|
||||||
# SpecAugment for whisper models
|
# SpecAugment for whisper models
|
||||||
if getattr(config, "model_type", None) == "whisper":
|
if getattr(config, "model_type", None) == "whisper":
|
||||||
config.update({"apply_spec_augment": model_args.apply_spec_augment})
|
config.update({"apply_spec_augment": model_args.apply_spec_augment})
|
||||||
@@ -440,9 +437,35 @@ def main():
|
|||||||
model.freeze_encoder()
|
model.freeze_encoder()
|
||||||
model.model.encoder.gradient_checkpointing = False
|
model.model.encoder.gradient_checkpointing = False
|
||||||
|
|
||||||
if data_args.language is not None:
|
if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
|
||||||
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
|
# We only need to set the language and task ids in a multilingual setting
|
||||||
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
|
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
|
||||||
|
model.generation_config.update(
|
||||||
|
**{
|
||||||
|
"language": data_args.language,
|
||||||
|
"task": data_args.task,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif data_args.language is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Setting language token for an English-only checkpoint is not permitted. The language argument should "
|
||||||
|
"only be set for multilingual checkpoints."
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO (Sanchit): remove these deprecated arguments in v4.41
|
||||||
|
if model_args.forced_decoder_ids is not None:
|
||||||
|
logger.warning(
|
||||||
|
"The use of `forced_decoder_ids` is deprecated and will be removed in v4.41. "
|
||||||
|
"Please use the `language` and `task` arguments instead."
|
||||||
|
)
|
||||||
|
model.generation_config.forced_decoder_ids = model_args.forced_decoder_ids
|
||||||
|
|
||||||
|
if model_args.suppress_tokens is not None:
|
||||||
|
logger.warning(
|
||||||
|
"The use of `suppress_tokens` is deprecated and will be removed in v4.41. "
|
||||||
|
"Should you need `suppress_tokens`, please manually set them in the fine-tuning script."
|
||||||
|
)
|
||||||
|
model.generation_config.suppress_tokens = model_args.suppress_tokens
|
||||||
|
|
||||||
# 6. Resample speech dataset if necessary
|
# 6. Resample speech dataset if necessary
|
||||||
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
|
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
|
||||||
|
|||||||
Reference in New Issue
Block a user