Add MMS CTC Fine-Tuning (#24281)
* Add mms ctc fine tuning * make style * More fixes that are needed * make fix-copies * make draft for README * add new file * move to new file * make style * make style * add quick test * make style * make style
This commit is contained in:
committed by
GitHub
parent
0c3fdccf2f
commit
1609a436ec
@@ -26,6 +26,10 @@ limitations under the License.
|
||||
- [Librispeech](#librispeech-ctc)
|
||||
- [Common Voice](#common-voice-ctc)
|
||||
- [Multilingual Librispeech](#multilingual-librispeech-ctc)
|
||||
- [Automatic Speech Recognition with CTC and Adapter Layers](#connectionist-temporal-classification-with-adapters)
|
||||
- [Massive Multilingual Speech (MMS)](#mms-model)
|
||||
- [Examples](#examples-ctc-adapter)
|
||||
- [Common Voice](#common-voice-ctc-adapter)
|
||||
- [Automatic Speech Recognition with Sequence-to-Sequence](#sequence-to-sequence)
|
||||
- [Whisper Model](#whisper-model)
|
||||
- [Speech-Encoder-Decoder Model](#warm-started-speech-encoder-decoder-model)
|
||||
@@ -243,6 +247,111 @@ they can serve as a baseline to improve upon.
|
||||
| [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)| `"german"` | [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) | 0.13 | - | 1 GPU Titan 24 GB RAM | 15h04 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-xlsr-53-300m-mls-german-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-xlsr-53-300m-mls-german-ft/blob/main/run.sh) |
|
||||
| [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)| `"german"` | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.15 | - | 1 GPU Titan 24 GB RAM | 15h04 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-300m-mls-german-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-300m-mls-german-ft/blob/main/run.sh) |
|
||||
|
||||
## Connectionist Temporal Classification With Adapters
|
||||
|
||||
The script [`run_speech_recognition_ctc_adapter.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/run_speech_recognition_ctc_adapter.py) can be used to fine-tune adapter layers for [Wav2Vec2-like models like MMS](https://huggingface.co/docs/transformers/main/en/model_doc/mms) for automatic speech recognition.
|
||||
|
||||
### MMS Model
|
||||
|
||||
The [Massive Multilingual Speech (MMS) model](https://huggingface.co/facebook/mms-1b-all) has been pre-trained and fine-tuned
|
||||
on 1000+ languages. The model makes use of adapter attention layers to fine-tune only a small part
|
||||
of the model on a specific language. The model already comes with fine-tuned adapter layers for 1000+ languages and
|
||||
can be used for inference for 1000+ languages out of the box.
|
||||
|
||||
However, for improved performance or more specific use cases one can re-initialize the adapter weights, freeze all
|
||||
other weights and fine-tune them on a specific dataset as shown in the [example below](#examples-ctc-adapter).
|
||||
|
||||
Note that the adapter weights include low dimensional linear layers for every attention block as well as the final language
|
||||
model head layers.
|
||||
|
||||
### Examples CTC Adapter
|
||||
|
||||
In the following we will look at how one can fine-tune adapter weights for any of the
|
||||
[MMS CTC checkpoints](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&other=mms&sort=downloads) in less than 1 hour.
|
||||
|
||||
#### Common Voice CTC Adapter
|
||||
|
||||
As in the examples [above](#examples-ctc), we fine-tune on Common Voice's 6 dataset in Turkish as an example.
|
||||
Contrary to [`run_speech_recognition_ctc.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py) before there is a `--target_language` which has to be defined to state for which
|
||||
language or concept the adapter layers shall be trained. The adapter weights will then
|
||||
accordingly be called `adapter.{<target_language}.safetensors`.
|
||||
|
||||
Let's run an example script. Make sure to be logged in so that your model can be directly uploaded to the Hub.
|
||||
```
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
Now, let's run an example and upload it to the Hub under `wav2vec2-common_voice-tr-mms-demo`.
|
||||
|
||||
```sh
|
||||
python run_speech_recognition_ctc.py \
|
||||
--dataset_name="common_voice" \
|
||||
--model_name_or_path="facebook/mms-1b-all" \
|
||||
--dataset_config_name="tr" \
|
||||
--output_dir="./wav2vec2-common_voice-tr-mms-demo" \
|
||||
--num_train_epochs="4" \
|
||||
--per_device_train_batch_size="32" \
|
||||
--learning_rate="1e-3" \
|
||||
--warmup_steps="100" \
|
||||
--evaluation_strategy="steps" \
|
||||
--text_column_name="sentence" \
|
||||
--length_column_name="input_length" \
|
||||
--save_steps="200" \
|
||||
--eval_steps="100" \
|
||||
--save_total_limit="3" \
|
||||
--target_language="tur" \
|
||||
--gradient_checkpointing \
|
||||
--chars_to_ignore , ? . ! - \; \: \" “ % ‘ ” <20> \
|
||||
--fp16 \
|
||||
--group_by_length \
|
||||
--do_train --do_eval \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
This should take less than 10 minutes on most GPUs and you should very quickly get word error rates
|
||||
below 27%.
|
||||
|
||||
For an example run, you can have a look at [`patrickvonplaten/wav2vec2-common_voice-tr-mms-demo`](https://huggingface.co/patrickvonplaten/wav2vec2-common_voice-tr-mms-demo).
|
||||
|
||||
|
||||
If you'd like to train another adapter model with the same base model, you can simply re-use the same `--output_dir`,
|
||||
but make sure to pass the `--output_dir` folder also to `--tokenizer_name_or_path` so that the vocabulary is not
|
||||
overwritten but **extended**. Assuming you would like to train adapter weights on Swedish in addition to Turkish and save
|
||||
the adapter weights in the same model repo, you can run:
|
||||
|
||||
```sh
|
||||
python run_speech_recognition_ctc.py \
|
||||
--dataset_name="common_voice" \
|
||||
--model_name_or_path="facebook/mms-1b-all" \
|
||||
--dataset_config_name="sw" \
|
||||
--output_dir="./wav2vec2-common_voice-tr-mms-demo" \
|
||||
--tokenizer_name_or_path="./wav2vec2-common_voice-tr-mms-demo" \
|
||||
--num_train_epochs="4" \
|
||||
--per_device_train_batch_size="32" \
|
||||
--learning_rate="1e-3" \
|
||||
--warmup_steps="100" \
|
||||
--evaluation_strategy="steps" \
|
||||
--text_column_name="sentence" \
|
||||
--length_column_name="input_length" \
|
||||
--save_steps="200" \
|
||||
--eval_steps="100" \
|
||||
--save_total_limit="3" \
|
||||
--target_language="swe" \
|
||||
--gradient_checkpointing \
|
||||
--chars_to_ignore , ? . ! - \; \: \" “ % ‘ ” <20> \
|
||||
--fp16 \
|
||||
--group_by_length \
|
||||
--do_train --do_eval \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
Now you should have both `adapter.tur.safetensors` and `adapter.swe.safetensors` in the model repo
|
||||
and you can load the respective language with:
|
||||
```py
|
||||
model.load_adapter("tur") # or "swe"
|
||||
```
|
||||
respectively.
|
||||
|
||||
## Sequence to Sequence
|
||||
|
||||
The script [`run_speech_recognition_seq2seq.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py) can be used to fine-tune any [Speech Sequence-to-Sequence Model](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForSpeechSeq2Seq) for automatic speech
|
||||
|
||||
Reference in New Issue
Block a user