Update MMS integration docs (#24311)

* Update mms.mdx

* Update mms.mdx

* Update docs/source/en/model_doc/mms.mdx

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update mms.mdx

* Update docs/source/en/model_doc/mms.mdx

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
Vineel Pratap
2023-06-19 06:49:01 -07:00
committed by GitHub
parent 5fca839fef
commit 7761b1893a

View File

@@ -30,21 +30,21 @@ for the same number of languages, as well as a language identification model for
Experiments show that our multilingual speech recognition model more than halves the word error rate of
Whisper on 54 languages of the FLEURS benchmark while being trained on a small fraction of the labeled data.*
Here are the different models open sourced in the MMS project. The models and code are originally released [here](https://github.com/facebookresearch/fairseq/tree/main/examples/mms). We have add them to the `transformers` framework, making them easier to use.
### Automatic Speech Recognition (ASR)
The ASR model checkpoints can be found here : [mms-1b-fl102](https://huggingface.co/facebook/mms-1b-fl102), [mms-1b-l1107](https://huggingface.co/facebook/mms-1b-l1107), [mms-1b-all](https://huggingface.co/facebook/mms-1b-all). For best accuracy, use the `mms-1b-all` model.
Tips:
- MMS is a speech model that accepts a float array corresponding to the raw waveform of the speech signal. The raw waveform should be pre-processed with [`Wav2Vec2FeatureExtractor`].
- MMS model was trained using connectionist temporal classification (CTC) so the model output has to be decoded using
- All ASR models accept a float array corresponding to the raw waveform of the speech signal. The raw waveform should be pre-processed with [`Wav2Vec2FeatureExtractor`].
- The models were trained using connectionist temporal classification (CTC) so the model output has to be decoded using
[`Wav2Vec2CTCTokenizer`].
- MMS can load different language adapter weights for different languages via [`~Wav2Vec2PreTrainedModel.load_adapter`]. Language adapters only consists of roughly 2 million parameters
- You can load different language adapter weights for different languages via [`~Wav2Vec2PreTrainedModel.load_adapter`]. Language adapters only consists of roughly 2 million parameters
and can therefore be efficiently loaded on the fly when needed.
Relevant checkpoints can be found under https://huggingface.co/models?other=mms.
MMS's architecture is based on the Wav2Vec2 model, so one can refer to [Wav2Vec2's documentation page](wav2vec2).
The original code can be found [here](https://github.com/facebookresearch/fairseq/tree/main/examples/mms).
## Loading
#### Loading
By default MMS loads adapter weights for English. If you want to load adapter weights of another language
make sure to specify `target_lang=<your-chosen-target-lang>` as well as `"ignore_mismatched_sizes=True`.
@@ -86,7 +86,7 @@ target_lang = "fra"
pipe = pipeline(model=model_id, model_kwargs={"target_lang": "fra", "ignore_mismatched_sizes": True})
```
## Inference
#### Inference
Next, let's look at how we can run MMS in inference and change adapter layers after having called [`~PretrainedModel.from_pretrained`]
First, we load audio data in different languages using the [Datasets](https://github.com/huggingface/datasets).
@@ -156,3 +156,81 @@ processor.tokenizer.vocab.keys()
```
to see all supported languages.
To further improve performance from ASR models, language model decoding can be used. See the documentation [here](https://huggingface.co/facebook/mms-1b-all) for further details.
### Speech Synthesis (TTS)
Individual TTS models are available for each of the 1100+ languages. The models and inference documentation can be found [here](https://huggingface.co/facebook/mms-tts).
### Language Identification (LID)
Different LID models are available based on the number of languages they can recognize - [126](https://huggingface.co/facebook/mms-lid-126), [256](https://huggingface.co/facebook/mms-lid-256), [512](https://huggingface.co/facebook/mms-lid-512), [1024](https://huggingface.co/facebook/mms-lid-1024), [2048](https://huggingface.co/facebook/mms-lid-2048), [4017](https://huggingface.co/facebook/mms-lid-4017).
#### Inference
First, we install transformers and some other libraries
```
pip install torch accelerate torchaudio datasets
pip install --upgrade transformers
````
pip install torch datasets[audio]
Next, we load a couple of audio samples via `datasets`. Make sure that the audio data is sampled to 16000 kHz.
```py
from datasets import load_dataset, Audio
# English
stream_data = load_dataset("mozilla-foundation/common_voice_13_0", "en", split="test", streaming=True)
stream_data = stream_data.cast_column("audio", Audio(sampling_rate=16000))
en_sample = next(iter(stream_data))["audio"]["array"]
# Arabic
stream_data = load_dataset("mozilla-foundation/common_voice_13_0", "ar", split="test", streaming=True)
stream_data = stream_data.cast_column("audio", Audio(sampling_rate=16000))
ar_sample = next(iter(stream_data))["audio"]["array"]
```
Next, we load the model and processor
```py
from transformers import Wav2Vec2ForSequenceClassification, AutoFeatureExtractor
import torch
model_id = "facebook/mms-lid-126"
processor = AutoFeatureExtractor.from_pretrained(model_id)
model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
```
Now we process the audio data, pass the processed audio data to the model to classify it into a language, just like we usually do for Wav2Vec2 audio classification models such as [ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition](https://huggingface.co/harshit345/xlsr-wav2vec-speech-emotion-recognition)
```py
# English
inputs = processor(en_sample, sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs).logits
lang_id = torch.argmax(outputs, dim=-1)[0].item()
detected_lang = model.config.id2label[lang_id]
# 'eng'
# Arabic
inputs = processor(ar_sample, sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs).logits
lang_id = torch.argmax(outputs, dim=-1)[0].item()
detected_lang = model.config.id2label[lang_id]
# 'ara'
```
To see all the supported languages of a checkpoint, you can print out the language ids as follows:
```py
processor.id2label.values()
```
### Audio Pretrained Models
Pretrained models are available for two different sizes - [300M](https://huggingface.co/facebook/mms-300m) , [1Bil](https://huggingface.co/facebook/mms-1b). The architecture is based on the Wav2Vec2 model, so one can refer to [Wav2Vec2's documentation page](wav2vec2) for further details on how to finetune with models for various downstream tasks.