Add ASR CTC streaming example (#15309)
* Single-epoch run
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Infinite dataset
* Trainer fix + distributed benchmark
* Benchmark fix
* unused import
* interleaved splits
* interleaved splits
* has_length util
* Move to research projects
* Leftover Sized checks
* Bump min version
* Unused import
* Revert trainer changes

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@@ -127,6 +127,62 @@ python -m torch.distributed.launch \
On 8 V100 GPUs, this script should run in *ca.* 18 minutes and yield a CTC loss of **0.39** and word error rate
of **0.36**.


### Multi GPU CTC with Dataset Streaming

The following command shows how to use [Dataset Streaming mode](https://huggingface.co/docs/datasets/dataset_streaming.html)
to fine-tune [XLS-R](https://huggingface.co/transformers/master/model_doc/xls_r.html)
on [Common Voice](https://huggingface.co/datasets/common_voice) using 4 GPUs in half-precision.

Streaming mode imposes several constraints on training:
1. We need to construct a tokenizer beforehand and define it via `--tokenizer_name_or_path`.
2. `--num_train_epochs` has to be replaced by `--max_steps`. Similarly, all other epoch-based arguments have to be
replaced by step-based ones.
3. Full dataset shuffling on each epoch is not possible, since we don't have the whole dataset available at once.
However, the `--shuffle_buffer_size` argument controls how many examples we can pre-download before shuffling them.
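
To make the third constraint more concrete, here is a minimal sketch of buffer-based shuffling over a stream, written in plain Python. It is an illustration under assumptions, not the script's or the 🤗 Datasets library's actual implementation: the function name `buffered_shuffle` and its parameters are hypothetical. The idea is that only `buffer_size` examples are held in memory at once, and each output is drawn at random from that buffer, so examples are shuffled locally rather than across the whole dataset.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(stream: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Yield the items of `stream` in a locally shuffled order.

    Hypothetical helper for illustration only: it keeps at most
    `buffer_size` items in memory and never needs the full dataset.
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")

    rng = random.Random(seed)
    buffer: List[T] = []

    for example in stream:
        if len(buffer) < buffer_size:
            # Fill the buffer before emitting anything.
            buffer.append(example)
            continue
        # Emit a random buffered item and put the new example in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = example

    # The stream is exhausted: flush what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


if __name__ == "__main__":
    # A buffer of 1 cannot reorder anything; a larger one mixes nearby items.
    print(list(buffered_shuffle(range(10), buffer_size=1)))
    print(list(buffered_shuffle(range(10), buffer_size=4)))
```

With this scheme an example can never be emitted more than `buffer_size - 1` positions ahead of where it arrived, which is why a larger buffer gives a better approximation of a full shuffle at the cost of more pre-downloaded data.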
|
||||
|
||||
|
||||
```bash
|
||||
**python -m torch.distributed.launch \
|
||||
--nproc_per_node 4 run_speech_recognition_ctc_streaming.py \
|
||||
--dataset_name="common_voice" \
|
||||
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \
|
||||
--tokenizer_name_or_path="anton-l/wav2vec2-tokenizer-turkish" \
|
||||
--dataset_config_name="tr" \
|
||||
--train_split_name="train+validation" \
|
||||
--eval_split_name="test" \
|
||||
--output_dir="wav2vec2-xls-r-common_voice-tr-ft" \
|
||||
--overwrite_output_dir \
|
||||
--max_steps="5000" \
|
||||
--per_device_train_batch_size="8" \
|
||||
--gradient_accumulation_steps="2" \
|
||||
--learning_rate="5e-4" \
|
||||
--warmup_steps="500" \
|
||||
--evaluation_strategy="steps" \
|
||||
--text_column_name="sentence" \
|
||||
--save_steps="500" \
|
||||
--eval_steps="500" \
|
||||
--logging_steps="1" \
|
||||
--layerdrop="0.0" \
|
||||
--eval_metrics wer cer \
|
||||
--save_total_limit="1" \
|
||||
--mask_time_prob="0.3" \
|
||||
--mask_time_length="10" \
|
||||
--mask_feature_prob="0.1" \
|
||||
--mask_feature_length="64" \
|
||||
--freeze_feature_encoder \
|
||||
--chars_to_ignore , ? . ! - \; \: \" “ % ‘ ” <20> \
|
||||
--max_duration_in_seconds="20" \
|
||||
--shuffle_buffer_size="500" \
|
||||
--fp16 \
|
||||
--push_to_hub \
|
||||
--do_train --do_eval \
|
||||
--gradient_checkpointing**
|
||||
```
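
As a rough sanity check on the step-based flags in the command above, the following sketch works out the effective batch size and the number of examples consumed in `--max_steps` updates. It assumes that one step corresponds to one optimizer update over `per_device_train_batch_size × number of GPUs × gradient_accumulation_steps` examples. The helper names and the dataset size of 10,000 used in the last line are hypothetical and only serve the illustration.

```python
import math


def effective_batch_size(per_device_batch_size: int, num_devices: int, grad_accum_steps: int) -> int:
    """Examples contributing to a single optimizer update (assumed definition)."""
    return per_device_batch_size * num_devices * grad_accum_steps


def examples_consumed(max_steps: int, batch: int) -> int:
    """Total examples drawn from the stream after `max_steps` updates."""
    return max_steps * batch


def steps_for_epochs(num_examples: int, num_epochs: int, batch: int) -> int:
    """Approximate step budget equivalent to `num_epochs` passes over the data."""
    return math.ceil(num_examples / batch) * num_epochs


# Values taken from the command above: batch size 8, 4 GPUs, 2 accumulation steps.
batch = effective_batch_size(8, 4, 2)
print(batch)                               # 64
print(examples_consumed(5000, batch))      # 320000

# Hypothetical dataset of 10,000 examples trained for 3 epochs.
print(steps_for_epochs(10_000, 3, batch))  # 471
```

The last helper shows one way to translate an epoch-based budget into a value for `--max_steps` when the dataset size is known from elsewhere, since a streamed dataset does not report its length.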

On 4 V100 GPUs, this script should run in *ca.* 3h 31min and yield a CTC loss of **0.35** and word error rate
of **0.29**.
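
For readers unfamiliar with the metrics requested via `--eval_metrics wer cer`, the sketch below shows the usual textbook definition of word and character error rate: the edit distance between reference and hypothesis divided by the reference length. It is a plain-Python illustration with hypothetical function names, not the metric implementation the script loads, and the example sentences are made up.

```python
from typing import Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance: minimum substitutions, insertions and deletions."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level edit distance divided by the number of reference words."""
    ref_words = reference.split()
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)


def char_error_rate(reference: str, hypothesis: str) -> float:
    """Character-level edit distance divided by the reference length."""
    return edit_distance(list(reference), list(hypothesis)) / len(reference)


# One substitution ("sat" -> "sit") and one deletion ("the") over 6 reference words.
print(word_error_rate("the cat sat on the mat", "the cat sit on mat"))
```

Under this definition, a word error rate of 0.29 means roughly 29 word-level edits per 100 reference words.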

### Examples CTC

The following tables present a couple of example runs on the most popular speech-recognition datasets.
@@ -175,6 +231,7 @@ they can serve as a baseline to improve upon.
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) | 0.35 | - | 1 GPU V100 | 1h20min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-common_voice-tr-demo) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-common_voice-tr-demo/blob/main/run.sh) |
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.31 | - | 8 GPU V100 | 1h05 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft/blob/main/run.sh) |
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-xls-r-1b](https://huggingface.co/facebook/wav2vec2-xls-r-1b) | 0.21 | - | 2 GPU Titan 24 GB RAM | 15h10 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-xls-r-1b-common_voice-tr-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-1b-common_voice-tr-ft/blob/main/run.sh) |
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` in streaming mode | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.29 | - | 4 GPU V100 | 3h31 | [here](https://huggingface.co/anton-l/wav2vec2-xls-r-common_voice-tr-ft-stream) | [run.sh](https://huggingface.co/anton-l/wav2vec2-xls-r-common_voice-tr-ft-stream/blob/main/run.sh) |


#### Multilingual Librispeech CTC