Add Speech Seq2Seq Training script (#14792)
* start
* add gradient checkpointing and feature extractor freezing
* Apply suggestions from code review
* up
* up
* up
* correct
* up
* more changes
* up
* up
* up
* remove rst
Committed by GitHub. Parent: 10fd4fa1a6. Commit: 1c121916f3.
@@ -14,12 +14,27 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
# Automatic Speech Recognition examples
|
||||
# Automatic Speech Recognition Examples
|
||||
|
||||
## Table of Contents
|
||||
|
||||
## Connectionist Temporal Classification without Language Model (CTC w/o LM)
|
||||
- [Automatic Speech Recognition with CTC](#connectionist-temporal-classification)
|
||||
- [Single GPU example](#single-gpu)
|
||||
- [Multi GPU example](#multi-gpu)
|
||||
- [Examples](#examples)
|
||||
- [TIMIT](#timit)
|
||||
- [Librispeech](#librispeech)
|
||||
- [Common Voice](#common-voice)
|
||||
- [Multilingual Librispeech](#multilingual-librispeech)
|
||||
- [Automatic Speech Recognition with Sequence-to-Sequence](#sequence-to-sequence)
|
||||
- [Single GPU example](#single-gpu)
|
||||
- [Multi GPU example](#multi-gpu)
|
||||
- [Examples](#examples)
|
||||
- [Librispeech](#librispeech)
|
||||
|
||||
The script [`run_speech_recognition_ctc.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py) can be used to fine-tune any pretrained [Connectionist Temporal Classification Model](https://huggingface.co/transformers/master/model_doc/auto.html?highlight=automodelforctc#automodelforctc) for automatic speech
|
||||
## Connectionist Temporal Classification
|
||||
|
||||
The script [`run_speech_recognition_ctc.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py) can be used to fine-tune any pretrained [Connectionist Temporal Classification Model](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCTC) for automatic speech
|
||||
recognition on one of the [official speech recognition datasets](https://huggingface.co/datasets?task_ids=task_ids:automatic-speech-recognition) or a custom dataset.
|
||||
|
||||
Speech recognition models that have been pretrained in unsupervised fashion on audio data alone, *e.g.* [Wav2Vec2](https://huggingface.co/transformers/master/model_doc/wav2vec2.html), [HuBERT](https://huggingface.co/transformers/master/model_doc/hubert.html), [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html), have shown to require only
|
||||
@@ -41,7 +56,7 @@ If the environment variable is not set, the training script might freeze, *i.e.*
|
||||
|
||||
---
|
||||
|
||||
### Single-GPU
|
||||
### Single GPU
|
||||
|
||||
The following command shows how to fine-tune [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html) on [Common Voice](https://huggingface.co/datasets/common_voice) using a single GPU in half-precision.
|
||||
|
||||
@@ -75,7 +90,7 @@ python run_speech_recognition_ctc.py \
|
||||
On a single V100 GPU, this script should run in *ca.* 1 hour 20 minutes and yield a CTC loss of **0.39** and word error rate
|
||||
of **0.35**.
|
||||
|
||||
### Multi-GPU
|
||||
### Multi GPU
|
||||
|
||||
The following command shows how to fine-tune [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html) on [Common Voice](https://huggingface.co/datasets/common_voice) using 8 GPUs in half-precision.
|
||||
|
||||
@@ -92,7 +107,6 @@ python -m torch.distributed.launch \
|
||||
--learning_rate="3e-4" \
|
||||
--warmup_steps="500" \
|
||||
--evaluation_strategy="steps" \
|
||||
--audio_column_name="path" \
|
||||
--text_column_name="sentence" \
|
||||
--save_steps="400" \
|
||||
--eval_steps="100" \
|
||||
@@ -118,6 +132,8 @@ The presented performances are by no means optimal as no hyper-parameter tuning
|
||||
they can serve as a baseline to improve upon.
|
||||
|
||||
|
||||
#### TIMIT
|
||||
|
||||
- [TIMIT](https://huggingface.co/datasets/timit_asr)
|
||||
|
||||
| Dataset | Dataset Config | Pretrained Model | Word error rate on eval | Phoneme error rate on eval | GPU setup | Training time | Fine-tuned Model & Logs | Command to reproduce |
|
||||
@@ -129,6 +145,7 @@ they can serve as a baseline to improve upon.
|
||||
| [TIMIT](https://huggingface.co/datasets/timit_asr)| - | [ntu-spml/distilhubert](https://huggingface.co/ntu-spml/distilhubert) | 0.68 | - | 1 GPU TITAN RTX | 26min | [here](https://huggingface.co/patrickvonplaten/distilhubert-timit) | [run.sh](https://huggingface.co/patrickvonplaten/distilhubert-timit/blob/main/run.sh) |
|
||||
|
||||
|
||||
#### Librispeech
|
||||
|
||||
- [Librispeech](https://huggingface.co/datasets/librispeech_asr)
|
||||
|
||||
@@ -139,7 +156,10 @@ they can serve as a baseline to improve upon.
|
||||
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60) | 0.042 | - | 8 GPU V100 | 1h30min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-librispeech-clean-100h-demo-dist) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-librispeech-clean-100h-demo-dist/blob/main/run.sh) |
|
||||
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60) | 0.042 | - | 8 GPU V100 | 1h30min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-librispeech-clean-100h-demo-dist) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-librispeech-clean-100h-demo-dist/blob/main/run.sh) |
|
||||
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/hubert-large-ll60k](https://huggingface.co/facebook/hubert-large-ll60k) | 0.088 | - | 8 GPU V100 | 1h30min | [here](https://huggingface.co/patrickvonplaten/hubert-librispeech-clean-100h-demo-dist) | [run.sh](https://huggingface.co/patrickvonplaten/hubert-librispeech-clean-100h-demo-dist/blob/main/run.sh) |
|
||||
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [asapp/sew-mid-100k](https://huggingface.co/asapp/sew-mid-100k) | 0.167 | | | 8 GPU V100 | 54min | [here](https://huggingface.co/patrickvonplaten/sew-mid-100k-librispeech-clean-100h-ft) | [run.sh](https://huggingface.co/patrickvonplaten/sew-mid-100k-librispeech-clean-100h-ft/blob/main/run.sh) |
|
||||
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [asapp/sew-mid-100k](https://huggingface.co/asapp/sew-mid-100k) | 0.167 | | 8 GPU V100 | 54min | [here](https://huggingface.co/patrickvonplaten/sew-mid-100k-librispeech-clean-100h-ft) | [run.sh](https://huggingface.co/patrickvonplaten/sew-mid-100k-librispeech-clean-100h-ft/blob/main/run.sh) |
|
||||
|
||||
|
||||
#### Common Voice
|
||||
|
||||
- [Common Voice](https://huggingface.co/datasets/common_voice)
|
||||
|
||||
@@ -154,9 +174,196 @@ they can serve as a baseline to improve upon.
|
||||
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.31 | - | 8 GPU V100 | 1h05 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft/blob/main/run.sh) |
|
||||
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-xls-r-1b](https://huggingface.co/facebook/wav2vec2-xls-r-1b) | 0.21 | - | 2 GPU Titan 24 GB RAM | 15h10 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-xls-r-1b-common_voice-tr-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-1b-common_voice-tr-ft/blob/main/run.sh) |
|
||||
|
||||
|
||||
#### Multilingual Librispeech
|
||||
|
||||
- [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)
|
||||
|
||||
| Dataset | Dataset Config | Pretrained Model | Word error rate on eval | Phoneme error rate on eval | GPU setup | Training time | Fine-tuned Model & Logs | Command to reproduce |
|
||||
|-------|------------------------------|-------------|---------------|---------------|----------------------|-------------| -------------| ------- |
|
||||
| [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)| `"german"` | [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) | 0.13 | - | 1 GPU Titan 24 GB RAM | 15h04 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-xlsr-53-300m-mls-german-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-xlsr-53-300m-mls-german-ft/blob/main/run.sh) |
|
||||
| [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)| `"german"` | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.15 | - | 1 GPU Titan 24 GB RAM | 15h04 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-300m-mls-german-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-300m-mls-german-ft/blob/main/run.sh) |
|
||||
|
||||
## Sequence to Sequence
|
||||
|
||||
The script [`run_speech_recognition_seq2seq.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py) can be used to fine-tune any [Speech Sequence-to-Sequence Model](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForSpeechSeq2Seq) for automatic speech
|
||||
recognition on one of the [official speech recognition datasets](https://huggingface.co/datasets?task_ids=task_ids:automatic-speech-recognition) or a custom dataset.
|
||||
|
||||
A very common use case is to leverage a pretrained speech [encoding model](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModel),
|
||||
*e.g.* [Wav2Vec2](https://huggingface.co/transformers/master/model_doc/wav2vec2.html), [HuBERT](https://huggingface.co/transformers/master/model_doc/hubert.html), [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html), with a pretrained [text decoding model](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModel), *e.g.* [Bart](https://huggingface.co/docs/transformers/master/en/model_doc/bart#transformers.BartForCausalLM), to create a [SpeechEncoderDecoderModel](https://huggingface.co/docs/transformers/master/en/model_doc/speechencoderdecoder#speech-encoder-decoder-models).
|
||||
The warm-started Speech-Encoder-Decoder model can then be fine-tuned with
this script.
|
||||
|
||||
As an example, let's instantiate a *Wav2Vec2-2-Bart* model with the `SpeechEncoderDecoderModel` framework:
|
||||
|
||||
First create an empty repo on `hf.co`:
|
||||
|
||||
```bash
|
||||
huggingface-cli repo create wav2vec2-2-bart-base
|
||||
git clone https://huggingface.co/<your-user-name>/wav2vec2-2-bart-base
|
||||
cd wav2vec2-2-bart-base
|
||||
```
|
||||
|
||||
Next, run the following script **inside** the just cloned repo:
|
||||
|
||||
```py
|
||||
from transformers import SpeechEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer, Wav2Vec2Processor
|
||||
|
||||
# checkpoints to leverage
|
||||
encoder_id = "facebook/wav2vec2-base"
|
||||
decoder_id = "facebook/bart-base"
|
||||
|
||||
# load and save speech-encoder-decoder model
|
||||
# set some hyper-parameters for training and evaluation
|
||||
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_add_adapter=True, encoder_feat_proj_dropout=0.0, encoder_layerdrop=0.0, max_length=200, num_beams=5)
|
||||
model.config.decoder_start_token_id = model.decoder.config.bos_token_id
|
||||
model.config.pad_token_id = model.decoder.config.pad_token_id
|
||||
model.config.eos_token_id = model.decoder.config.eos_token_id
|
||||
model.save_pretrained("./")
|
||||
|
||||
# load and save processor
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(decoder_id)
|
||||
processor = Wav2Vec2Processor(feature_extractor, tokenizer)
|
||||
processor.save_pretrained("./")
|
||||
```
|
||||
|
||||
Finally, we can upload all files:
|
||||
```bash
|
||||
git lfs install
|
||||
git add . && git commit -m "upload model files" && git push
|
||||
```
|
||||
|
||||
and link the official `run_speech_recognition_seq2seq.py` script to the folder:
|
||||
|
||||
```bash
|
||||
ln -s $(realpath <path/to/transformers>/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py) ./
|
||||
```
|
||||
|
||||
Note that we have added a randomly initialized adapter to `wav2vec2-base` with
`encoder_add_adapter=True`, which further subsamples the output sequence of
`wav2vec2-base` along the time dimension. The reason is that by default a single
output vector of `wav2vec2-base` has a receptive field of *ca.* 25ms (*cf.*
Section *4.2* of the [official Wav2Vec2 paper](https://arxiv.org/pdf/2006.11477.pdf)), which represents a little less than a single character. BART, on the other hand,
uses a subword tokenizer (byte-level BPE) as an input processor, so that a single
hidden vector of `bart-base` represents *ca.* 4 characters. To better align
the output of *Wav2Vec2* and *BART*'s hidden vectors for the cross-attention
mechanism, we further subsample *Wav2Vec2*'s output by a factor of 8 by
adding a convolution-based adapter.
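
The subsampling arithmetic above can be sketched in plain Python. The kernel and stride values are assumptions taken from the default `wav2vec2-base` configuration (seven convolutional feature-extractor layers) and the default adapter settings (three layers with kernel size 3, stride 2 and padding 1). They are not stated in this README, so check them against the model config before relying on the numbers.

```python
# Assumed defaults, not taken from this README: wav2vec2-base feature extractor
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)
# Assumed adapter defaults: 3 conv layers, kernel 3, stride 2, padding 1
NUM_ADAPTER_LAYERS = 3
ADAPTER_KERNEL, ADAPTER_STRIDE, ADAPTER_PADDING = 3, 2, 1
SAMPLING_RATE = 16_000  # samples per second


def conv_out_length(length, kernel, stride, padding=0):
    """Output length of a 1D convolution (dilation 1)."""
    return (length + 2 * padding - kernel) // stride + 1


def encoder_frames(num_samples):
    """Number of output vectors the feature extractor yields for raw audio."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = conv_out_length(length, kernel, stride)
    return length


def adapter_frames(num_frames):
    """Number of vectors left after the stacked adapter layers."""
    for _ in range(NUM_ADAPTER_LAYERS):
        num_frames = conv_out_length(
            num_frames, ADAPTER_KERNEL, ADAPTER_STRIDE, ADAPTER_PADDING
        )
    return num_frames


def receptive_field_and_hop(kernels, strides):
    """Receptive field and hop size (both in samples) of a conv stack."""
    field, hop = 1, 1
    for kernel, stride in zip(kernels, strides):
        field += (kernel - 1) * hop
        hop *= stride
    return field, hop


field, hop = receptive_field_and_hop(CONV_KERNELS, CONV_STRIDES)
print(f"receptive field: {1000 * field / SAMPLING_RATE} ms")  # 25.0 ms
print(f"hop size: {1000 * hop / SAMPLING_RATE} ms")  # 20.0 ms

frames = encoder_frames(SAMPLING_RATE)  # one second of audio
print(frames, adapter_frames(frames))  # 49 7
```

Under these assumptions, one second of audio yields 49 encoder vectors and 7 vectors after the adapter, i.e. roughly the factor of 8 mentioned above.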
|
||||
|
||||
Having warm-started the speech-encoder-decoder model `<your-user-name>/wav2vec2-2-bart-base`, we can now fine-tune it on speech recognition.
|
||||
|
||||
In the script [`run_speech_recognition_seq2seq.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py), we load the warm-started model,
the feature extractor, and the tokenizer, process a speech recognition dataset,
and then make use of the [`Seq2SeqTrainer`](https://huggingface.co/docs/transformers/master/en/main_classes/trainer#transformers.Seq2SeqTrainer).
Note that it is important to also align the decoder's vocabulary with
the speech transcriptions of the dataset. *E.g.* the [`Librispeech`](https://huggingface.co/datasets/librispeech_asr) dataset has only capitalized letters in its transcriptions,
whereas BART was pretrained mostly on normalized text. Thus it is recommended to add
`--do_lower_case` to the fine-tuning script when using a warm-started `SpeechEncoderDecoderModel`. The model is fine-tuned on the standard cross-entropy language modeling
loss for sequence-to-sequence (just like *T5* or *BART* in natural language processing).
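
As a rough illustration of the lower-casing step, the sketch below normalizes a transcription before it would be tokenized. The helper name `normalize_transcript` and the sample sentence are hypothetical; the script's actual preprocessing may differ.

```python
def normalize_transcript(text: str, do_lower_case: bool = True) -> str:
    """Lower-case a transcription so it better matches the decoder's pretraining text.

    Hypothetical helper for illustration only.
    """
    return text.lower() if do_lower_case else text


# Made-up sample in the all-caps style described above
example = "THE QUICK BROWN FOX"
print(normalize_transcript(example))  # the quick brown fox
print(normalize_transcript(example, do_lower_case=False))  # THE QUICK BROWN FOX
```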
|
||||
|
||||
---
|
||||
**NOTE**
|
||||
|
||||
If you encounter problems with data preprocessing when setting `--preprocessing_num_workers` > 1,
you might want to set the environment variable `OMP_NUM_THREADS` to 1 as follows:
|
||||
|
||||
```bash
|
||||
OMP_NUM_THREADS=1 python run_speech_recognition_seq2seq.py ...
|
||||
```
|
||||
|
||||
If the environment variable is not set, the training script might freeze; see: https://github.com/pytorch/audio/issues/1021#issuecomment-726915239
|
||||
|
||||
---
|
||||
|
||||
### Single GPU
|
||||
|
||||
The following command shows how to fine-tune the warm-started *Wav2Vec2-2-Bart* model on [Librispeech](https://huggingface.co/datasets/librispeech_asr) using a single GPU in half-precision.
|
||||
|
||||
```bash
|
||||
python run_speech_recognition_seq2seq.py \
|
||||
--dataset_name="librispeech_asr" \
|
||||
--model_name_or_path="./" \
|
||||
--dataset_config_name="clean" \
|
||||
--train_split_name="train.100" \
|
||||
--eval_split_name="validation" \
|
||||
--output_dir="./" \
|
||||
--preprocessing_num_workers="16" \
|
||||
--length_column_name="input_length" \
|
||||
--overwrite_output_dir \
|
||||
--num_train_epochs="5" \
|
||||
--per_device_train_batch_size="8" \
|
||||
--per_device_eval_batch_size="8" \
|
||||
--gradient_accumulation_steps="8" \
|
||||
--learning_rate="3e-4" \
|
||||
--warmup_steps="400" \
|
||||
--evaluation_strategy="steps" \
|
||||
--text_column_name="text" \
|
||||
--save_steps="400" \
|
||||
--eval_steps="400" \
|
||||
--logging_steps="10" \
|
||||
--save_total_limit="1" \
|
||||
--freeze_feature_extractor \
|
||||
--gradient_checkpointing \
|
||||
--fp16 \
|
||||
--group_by_length \
|
||||
--predict_with_generate \
|
||||
--generation_max_length="40" \
|
||||
--generation_num_beams="1" \
|
||||
--do_train --do_eval \
|
||||
--do_lower_case
|
||||
```
|
||||
|
||||
On a single V100 GPU, this script should run in *ca.* 5 hours and yield a
|
||||
cross-entropy loss of **0.405** and word error rate of **0.0728**.
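
For readers unfamiliar with the metric, word error rate is the word-level edit distance between reference and prediction, divided by the number of reference words. The following is a minimal, self-contained sketch of that definition. It is not the metric implementation used by the script, and the function name is chosen here for illustration.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level Levenshtein distance divided by the number of reference words."""
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    if not ref_words:
        raise ValueError("reference must contain at least one word")

    # previous[j] = edit distance between the processed reference prefix
    # and the first j hypothesis words
    previous = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        current = [i]
        for j, hyp_word in enumerate(hyp_words, start=1):
            substitution = previous[j - 1] + (ref_word != hyp_word)
            deletion = previous[j] + 1
            insertion = current[j - 1] + 1
            current.append(min(substitution, deletion, insertion))
        previous = current
    return previous[-1] / len(ref_words)


# One substituted word out of four reference words
print(word_error_rate("a b c d", "a x c d"))  # 0.25
```

A word error rate of 0.0728 therefore corresponds to roughly 7 word-level errors per 100 reference words.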
|
||||
|
||||
### Multi GPU
|
||||
|
||||
The following command shows how to fine-tune the warm-started *Wav2Vec2-2-Bart* model on [Librispeech](https://huggingface.co/datasets/librispeech_asr) using 8 GPUs in half-precision.
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.launch \
|
||||
--nproc_per_node 8 run_speech_recognition_seq2seq.py \
|
||||
--dataset_name="librispeech_asr" \
|
||||
--model_name_or_path="./" \
|
||||
--dataset_config_name="clean" \
|
||||
--train_split_name="train.100" \
|
||||
--eval_split_name="validation" \
|
||||
--output_dir="./" \
|
||||
--preprocessing_num_workers="16" \
|
||||
--length_column_name="input_length" \
|
||||
--overwrite_output_dir \
|
||||
--num_train_epochs="5" \
|
||||
--per_device_train_batch_size="8" \
|
||||
--per_device_eval_batch_size="8" \
|
||||
--gradient_accumulation_steps="1" \
|
||||
--learning_rate="3e-4" \
|
||||
--warmup_steps="400" \
|
||||
--evaluation_strategy="steps" \
|
||||
--text_column_name="text" \
|
||||
--save_steps="400" \
|
||||
--eval_steps="400" \
|
||||
--logging_steps="10" \
|
||||
--save_total_limit="1" \
|
||||
--freeze_feature_extractor \
|
||||
--gradient_checkpointing \
|
||||
--fp16 \
|
||||
--group_by_length \
|
||||
--predict_with_generate \
|
||||
--do_train --do_eval \
|
||||
--do_lower_case
|
||||
```
|
||||
|
||||
On 8 V100 GPUs, this script should run in *ca.* 45 minutes and yield a cross-entropy loss of **0.405** and word error rate of **0.0728**.
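
The single GPU and multi GPU commands differ in `--gradient_accumulation_steps` (8 versus 1). A plausible reading, not stated explicitly in this README, is that this keeps the number of examples per optimizer step the same in both setups. The small check below uses a helper name chosen for illustration.

```python
def effective_batch_size(per_device_batch_size: int, num_devices: int, gradient_accumulation_steps: int) -> int:
    """Examples that contribute to a single optimizer step."""
    return per_device_batch_size * num_devices * gradient_accumulation_steps


# Values taken from the two commands above
single_gpu = effective_batch_size(per_device_batch_size=8, num_devices=1, gradient_accumulation_steps=8)
multi_gpu = effective_batch_size(per_device_batch_size=8, num_devices=8, gradient_accumulation_steps=1)
print(single_gpu, multi_gpu)  # 64 64
```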
|
||||
|
||||
### Examples
|
||||
|
||||
#### Librispeech
|
||||
|
||||
- [Librispeech](https://huggingface.co/datasets/librispeech_asr)
|
||||
|
||||
| Dataset | Dataset Config | Pretrained Model | Word error rate on eval | Phoneme error rate on eval | GPU setup | Training time | Fine-tuned Model & Logs | Command to reproduce |
|
||||
|-------|------------------------------|-------------|---------------|---------------|----------------------|-------------| -------------| ------- |
|
||||
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) and [facebook/bart-base](https://huggingface.co/facebook/bart-base) | 0.0728 | - | 8 GPU V100 | 45min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-base) | [create_model.py](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-base/blob/main/create_model.py) & [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-base/blob/main/run_librispeech.sh) |
|
||||
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60) and [facebook/bart-large](https://huggingface.co/facebook/bart-large) | 0.0486 | - | 8 GPU V100 | 1h20min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-large) | [create_model.py](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-large/blob/main/create_model.py) & [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-large/blob/main/run_librispeech.sh) |
|
||||
|
||||
@@ -635,14 +635,13 @@ def main():
|
||||
|
||||
return metrics
|
||||
|
||||
# Now create a single processor
|
||||
# Now save everything to be able to create a single processor later
|
||||
if is_main_process(training_args.local_rank):
|
||||
# save feature extractor, tokenizer and config
|
||||
feature_extractor.save_pretrained(training_args.output_dir)
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
config.save_pretrained(training_args.output_dir)
|
||||
|
||||
# load processor
|
||||
try:
|
||||
processor = AutoProcessor.from_pretrained(training_args.output_dir)
|
||||
except (OSError, KeyError):
|
||||
|
||||
502
examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py
Executable file
@@ -0,0 +1,502 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fine-tuning the library models for sequence to sequence speech recognition.
|
||||
"""
|
||||
# You can also adapt this script on your own sequence to sequence speech
|
||||
# recognition task. Pointers for this are left as comments.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import DatasetDict, load_dataset, load_metric
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoFeatureExtractor,
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||
check_min_version("4.16.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
feature_extractor_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
freeze_feature_extractor: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
text_column: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of the column in the datasets containing the full texts."},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
audio_column_name: Optional[str] = field(
|
||||
default="audio",
|
||||
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
|
||||
)
|
||||
text_column_name: Optional[str] = field(
|
||||
default="text",
|
||||
metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
|
||||
)
|
||||
max_duration_in_seconds: Optional[float] = field(
|
||||
default=20.0,
|
||||
metadata={
"help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to `max_duration_in_seconds`"
|
||||
},
|
||||
)
|
||||
min_duration_in_seconds: Optional[float] = field(
|
||||
default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
|
||||
)
|
||||
preprocessing_only: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to only do data preprocessing and skip training. "
|
||||
"This is especially useful when data preprocessing errors out in distributed training due to timeout. "
|
||||
"In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
|
||||
"so that the cached datasets can consequently be loaded in distributed training"
|
||||
},
|
||||
)
|
||||
train_split_name: Optional[str] = field(
|
||||
default="train",
|
||||
metadata={
|
||||
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
|
||||
},
|
||||
)
|
||||
eval_split_name: Optional[str] = field(
|
||||
default="test",
|
||||
metadata={
"help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
|
||||
},
|
||||
)
|
||||
do_lower_case: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether the target text should be lower cased."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorSpeechSeq2SeqWithPadding:
|
||||
"""
|
||||
Data collator that will dynamically pad the inputs received.
|
||||
Args:
|
||||
processor ([`Wav2Vec2Processor`])
|
||||
The processor used for processing the data.
|
||||
decoder_start_token_id (`int`)
|
||||
The begin-of-sentence token id of the decoder.
|
||||
"""
|
||||
|
||||
processor: Any
|
||||
decoder_start_token_id: int
|
||||
|
||||
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||
# split inputs and labels since they have to be of different lengths and need
|
||||
# different padding methods
|
||||
input_features = [{"input_values": feature["input_values"]} for feature in features]
|
||||
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
||||
|
||||
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
||||
|
||||
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
|
||||
|
||||
# replace padding with -100 to ignore loss correctly
|
||||
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
||||
|
||||
# if bos token is appended in previous tokenization step,
|
||||
# cut bos token here as it's appended later anyway
|
||||
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
|
||||
labels = labels[:, 1:]
|
||||
|
||||
batch["labels"] = labels
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def main():
|
||||
# 1. Parse input arguments
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
|
||||
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
# 2. Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
|
||||
f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
if is_main_process(training_args.local_rank):
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
logger.info("Training/evaluation parameters %s", training_args)
|
||||
|
||||
# 3. Detecting last checkpoint and eventually continue from last checkpoint
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(
|
||||
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
||||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||
)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# 4. Load dataset
|
||||
raw_datasets = DatasetDict()
|
||||
|
||||
if training_args.do_train:
|
||||
raw_datasets["train"] = load_dataset(
|
||||
data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name
|
||||
)
|
||||
|
||||
if training_args.do_eval:
|
||||
raw_datasets["eval"] = load_dataset(
|
||||
data_args.dataset_name, data_args.dataset_config_name, split=data_args.eval_split_name
|
||||
)
|
||||
|
||||
if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
|
||||
raise ValueError(
|
||||
f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
|
||||
"Make sure to set `--audio_column_name` to the correct audio column - one of "
|
||||
f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
|
||||
)
|
||||
|
||||
if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
|
||||
raise ValueError(
|
||||
f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
|
||||
"Make sure to set `--text_column_name` to the correct text column - one of "
|
||||
f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
|
||||
)
|
||||
|
||||
# 5. Load pretrained model, tokenizer, and feature extractor
|
||||
#
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently download model & vocab.
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
if model_args.freeze_feature_extractor:
|
||||
model.freeze_feature_extractor()
|
||||
|
||||
# 6. Resample speech dataset if necessary
|
||||
dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
|
||||
if dataset_sampling_rate != feature_extractor.sampling_rate:
|
||||
raw_datasets = raw_datasets.cast_column(
|
||||
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
||||
)
|
||||
|
||||
# 7. Preprocessing the datasets.
|
||||
# We need to read the audio files as arrays and tokenize the targets.
|
||||
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
|
||||
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
||||
audio_column_name = data_args.audio_column_name
|
||||
num_workers = data_args.preprocessing_num_workers
|
||||
text_column_name = data_args.text_column_name
|
||||
model_input_name = feature_extractor.model_input_names[0]
|
||||
do_lower_case = data_args.do_lower_case
|
||||
|
||||
if data_args.max_train_samples is not None:
|
||||
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
|
||||
|
||||
if data_args.max_eval_samples is not None:
|
||||
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|
||||
|
||||
def prepare_dataset(batch):
|
||||
# process audio
|
||||
sample = batch[audio_column_name]
|
||||
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
|
||||
# process audio length
|
||||
batch[model_input_name] = getattr(inputs, model_input_name)[0]
|
||||
batch["input_length"] = len(batch[model_input_name])
|
||||
|
||||
# process targets
|
||||
input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
|
||||
batch["labels"] = tokenizer(input_str).input_ids
|
||||
return batch
|
||||
|
||||
with training_args.main_process_first(desc="dataset map pre-processing"):
|
||||
vectorized_datasets = raw_datasets.map(
|
||||
prepare_dataset,
|
||||
remove_columns=next(iter(raw_datasets.values())).column_names,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
desc="preprocess train dataset",
|
||||
)
|
||||
|
||||
# filter data that is shorter than min_input_length or longer than
|
||||
# max_input_length
|
||||
def is_audio_in_length_range(length):
|
||||
return length > min_input_length and length < max_input_length
|
||||
|
||||
vectorized_datasets = vectorized_datasets.filter(
|
||||
is_audio_in_length_range,
|
||||
num_proc=num_workers,
|
||||
input_columns=["input_length"],
|
||||
)
|
||||
|
||||
# for large datasets it is advised to run the preprocessing on a
|
||||
# single machine first with `args.preprocessing_only` since there will most likely
|
||||
# be a timeout when running the script in distributed mode.
|
||||
# In a second step `args.preprocessing_only` can then be set to `False` to load the
|
||||
# cached dataset
|
||||
if data_args.preprocessing_only:
|
||||
cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
|
||||
logger.info(f"Data preprocessing finished. Files cached at {cache}.")
|
||||
return
|
||||
|
||||
# 8. Load Metric
|
||||
metric = load_metric("wer")
|
||||
|
||||
def compute_metrics(pred):
|
||||
pred_ids = pred.predictions
|
||||
|
||||
pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
|
||||
|
||||
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
||||
# we do not want to group tokens when computing the metrics
|
||||
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
|
||||
|
||||
wer = metric.compute(predictions=pred_str, references=label_str)
|
||||
|
||||
return {"wer": wer}
|
||||
|
||||
# 9. Create a single speech processor
|
||||
if is_main_process(training_args.local_rank):
|
||||
# save feature extractor, tokenizer and config
|
||||
feature_extractor.save_pretrained(training_args.output_dir)
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
config.save_pretrained(training_args.output_dir)
|
||||
|
||||
processor = AutoProcessor.from_pretrained(training_args.output_dir)
|
||||
|
||||
# 10. Define data collator
|
||||
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
|
||||
processor=processor, decoder_start_token_id=model.config.decoder_start_token_id
|
||||
)
|
||||
|
||||
# 11. Initialize Trainer
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
|
||||
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
|
||||
tokenizer=feature_extractor,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
)
|
||||
|
||||
# 12. Training
|
||||
if training_args.do_train:
|
||||
checkpoint = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
checkpoint = training_args.resume_from_checkpoint
|
||||
elif last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.save_model() # Saves the feature extractor too for easy upload
|
||||
|
||||
metrics = train_result.metrics
|
||||
max_train_samples = (
|
||||
data_args.max_train_samples
|
||||
if data_args.max_train_samples is not None
|
||||
else len(vectorized_datasets["train"])
|
||||
)
|
||||
metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
# 13. Evaluation
|
||||
results = {}
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate(
|
||||
metric_key_prefix="eval", max_length=model.config.max_length, num_beams=model.config.num_beams
|
||||
)
|
||||
max_eval_samples = (
|
||||
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
|
||||
)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# 14. Write Training Stats
|
||||
kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "speech recognition"}
|
||||
if data_args.dataset_name is not None:
|
||||
kwargs["dataset_tags"] = data_args.dataset_name
|
||||
if data_args.dataset_config_name is not None:
|
||||
kwargs["dataset_args"] = data_args.dataset_config_name
|
||||
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
|
||||
else:
|
||||
kwargs["dataset"] = data_args.dataset_name
|
||||
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**kwargs)
|
||||
else:
|
||||
trainer.create_model_card(**kwargs)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
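For context on step 8 of the script above: the `wer` metric reports word error rate, i.e. the word-level edit distance between prediction and reference divided by the number of reference words. The block below is a minimal pure-Python sketch of that idea, not the implementation behind `load_metric("wer")`. The helper names `word_edit_distance` and `word_error_rate` are hypothetical, and pooling errors over the whole batch before dividing is an assumption.

```python
def word_edit_distance(reference, hypothesis):
    """Levenshtein distance between two lists of words."""
    # previous[j] holds the distance between the reference words seen so far
    # and the first j hypothesis words.
    previous = list(range(len(hypothesis) + 1))
    for i, ref_word in enumerate(reference, start=1):
        current = [i]
        for j, hyp_word in enumerate(hypothesis, start=1):
            substitution = previous[j - 1] + (ref_word != hyp_word)
            deletion = previous[j] + 1  # reference word missing from hypothesis
            insertion = current[j - 1] + 1  # extra word in hypothesis
            current.append(min(substitution, deletion, insertion))
        previous = current
    return previous[-1]


def word_error_rate(predictions, references):
    """Total word errors divided by total reference words (assumed pooling)."""
    errors = 0
    total_words = 0
    for prediction, reference in zip(predictions, references):
        reference_words = reference.split()
        errors += word_edit_distance(reference_words, prediction.split())
        total_words += len(reference_words)
    return errors / total_words
```

A lower value is better, and the value can exceed 1.0 when the hypothesis contains many inserted words.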
@@ -59,6 +59,7 @@ if SRC_DIRS is not None:
|
||||
import run_qa as run_squad
|
||||
import run_seq2seq_qa as run_squad_seq2seq
|
||||
import run_speech_recognition_ctc
|
||||
import run_speech_recognition_seq2seq
|
||||
import run_summarization
|
||||
import run_swag
|
||||
import run_translation
|
||||
@@ -473,6 +474,39 @@ class ExamplesTests(TestCasePlus):
|
||||
result = get_results(tmp_dir)
|
||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||
|
||||
def test_run_speech_recognition_seq2seq(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_speech_recognition_seq2seq.py
|
||||
--output_dir {tmp_dir}
|
||||
--model_name_or_path hf-internal-testing/tiny-random-speech-encoder-decoder
|
||||
--dataset_name hf-internal-testing/librispeech_asr_dummy
|
||||
--dataset_config_name clean
|
||||
--train_split_name validation
|
||||
--eval_split_name validation
|
||||
--do_train
|
||||
--do_eval
|
||||
--learning_rate 1e-4
|
||||
--per_device_train_batch_size 2
|
||||
--per_device_eval_batch_size 4
|
||||
--remove_unused_columns False
|
||||
--overwrite_output_dir True
|
||||
--preprocessing_num_workers 16
|
||||
--max_steps 10
|
||||
--seed 42
|
||||
""".split()
|
||||
|
||||
if is_cuda_and_apex_available():
|
||||
testargs.append("--fp16")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_speech_recognition_seq2seq.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||
|
||||
def test_run_audio_classification(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
@@ -521,10 +555,10 @@ class ExamplesTests(TestCasePlus):
|
||||
--dataset_config_names clean
|
||||
--dataset_split_names validation
|
||||
--learning_rate 1e-4
|
||||
--per_device_train_batch_size 2
|
||||
--per_device_eval_batch_size 2
|
||||
--per_device_train_batch_size 4
|
||||
--per_device_eval_batch_size 4
|
||||
--preprocessing_num_workers 16
|
||||
--max_train_steps 5
|
||||
--max_train_steps 2
|
||||
--validation_split_percentage 5
|
||||
--seed 42
|
||||
""".split()
|
||||
|
||||
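The test above drives the training script by temporarily replacing `sys.argv` and then calling the script's `main()`. The following is a small sketch of that pattern using only the standard library. The `main` function and its three flags are hypothetical stand-ins, not the real script's argument parser.

```python
import argparse
import sys
from unittest.mock import patch


def main():
    # Hypothetical stand-in for a training script's entry point.
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--max_steps", type=int, default=-1)
    parser.add_argument("--do_train", action="store_true")
    # parse_args() reads sys.argv[1:] at call time, so the patch below applies.
    return parser.parse_args()


# The first element plays the role of the program name, as in a real command line.
testargs = """
    run_script.py
    --output_dir /tmp/example
    --max_steps 10
    --do_train
    """.split()

with patch.object(sys, "argv", testargs):
    args = main()
```

Because `patch.object` is used as a context manager, the original `sys.argv` is restored when the block exits, so one test's arguments do not leak into the next.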
@@ -164,7 +164,7 @@ class AutoProcessor:
|
||||
model_type = config_class_to_model_type(type(config).__name__)
|
||||
|
||||
if getattr(config, "processor_class", None) is not None:
|
||||
processor_class = config.processor_class
|
||||
processor_class = processor_class_from_name(config.processor_class)
|
||||
return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
model_type = config_class_to_model_type(type(config).__name__)
|
||||
|
||||
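The hunk above shows two versions of the assignment: one that uses `config.processor_class` directly and one that passes it through `processor_class_from_name`. Assuming the config stores the class name as a string (an assumption, not confirmed by the diff itself), the lookup is needed because a string has no `from_pretrained` method. The sketch below illustrates the idea with a hypothetical registry. `DummyProcessor`, `PROCESSOR_REGISTRY`, and `class_from_name` are illustrative names, not library APIs.

```python
class DummyProcessor:
    """Hypothetical processor class used only for this illustration."""

    @classmethod
    def from_pretrained(cls, path):
        return f"{cls.__name__} loaded from {path}"


# Hypothetical registry mapping class names to classes.
PROCESSOR_REGISTRY = {"DummyProcessor": DummyProcessor}


def class_from_name(name):
    """Resolve a class-name string to the class object it names."""
    try:
        return PROCESSOR_REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unknown processor class: {name}") from None


processor_class_name = "DummyProcessor"  # what a config would store (assumed)
processor_class = class_from_name(processor_class_name)
result = processor_class.from_pretrained("some/path")
```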
@@ -905,6 +905,7 @@ class HubertModel(HubertPreTrainedModel):
|
||||
self.feature_extractor = HubertFeatureExtractor(config)
|
||||
self.feature_projection = HubertFeatureProjection(config)
|
||||
|
||||
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
|
||||
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
|
||||
|
||||
if config.do_stable_layer_norm:
|
||||
|
||||
@@ -805,6 +805,7 @@ class SEWModel(SEWPreTrainedModel):
|
||||
self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
|
||||
self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
|
||||
|
||||
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
|
||||
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
|
||||
|
||||
self.encoder = SEWEncoder(config)
|
||||
|
||||
@@ -1341,6 +1341,7 @@ class SEWDModel(SEWDPreTrainedModel):
|
||||
self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
|
||||
self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
|
||||
|
||||
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
|
||||
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
|
||||
|
||||
self.encoder = SEWDEncoder(config)
|
||||
|
||||
@@ -181,6 +181,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
config_class = SpeechEncoderDecoderConfig
|
||||
base_model_prefix = "speech_encoder_decoder"
|
||||
main_input_name = "inputs"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -247,6 +248,11 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
|
||||
)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
# call both encoder and decoder function on gradient checkpointing
|
||||
self.encoder._set_gradient_checkpointing(module, value=value)
|
||||
self.decoder._set_gradient_checkpointing(module, value=value)
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
|
||||
@@ -259,6 +265,13 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
return self.decoder.set_output_embeddings(new_embeddings)
|
||||
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
Calling this function will disable the gradient computation for the feature extractor of the speech encoder so
|
||||
that its parameters will not be updated during training.
|
||||
"""
|
||||
self.encoder.freeze_feature_extractor()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
# At the moment fast initialization is not supported for composite models
|
||||
@@ -367,7 +380,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
)
|
||||
|
||||
if "config" not in kwargs_encoder:
|
||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
|
||||
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path, **kwargs_encoder)
|
||||
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
|
||||
logger.info(
|
||||
f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model "
|
||||
@@ -378,7 +391,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
|
||||
kwargs_encoder["config"] = encoder_config
|
||||
|
||||
encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
|
||||
encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args)
|
||||
|
||||
decoder = kwargs_decoder.pop("model", None)
|
||||
if decoder is None:
|
||||
@@ -389,7 +402,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
)
|
||||
|
||||
if "config" not in kwargs_decoder:
|
||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
|
||||
logger.info(
|
||||
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
|
||||
@@ -411,7 +424,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
|
||||
)
|
||||
|
||||
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
|
||||
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path)
|
||||
|
||||
# instantiate config with corresponding kwargs
|
||||
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
|
||||
|
||||
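The `SpeechEncoderDecoderModel` hunks above add two methods that simply forward a request to the wrapped sub-models: gradient checkpointing goes to both encoder and decoder, while feature-extractor freezing goes to the encoder only. A framework-free sketch of that delegation pattern follows. `TinySubModel` and `TinyComposite` are hypothetical classes, and the boolean attributes stand in for the real effects on parameters.

```python
class TinySubModel:
    """Hypothetical stand-in for an encoder or decoder."""

    def __init__(self):
        self.gradient_checkpointing = False
        self.feature_extractor_frozen = False

    def set_gradient_checkpointing(self, value):
        self.gradient_checkpointing = value

    def freeze_feature_extractor(self):
        self.feature_extractor_frozen = True


class TinyComposite:
    """Hypothetical composite that owns an encoder and a decoder."""

    def __init__(self, encoder, decoder):
        self.encoder = encoder
        self.decoder = decoder

    def set_gradient_checkpointing(self, value):
        # Forward the flag to both sub-models.
        self.encoder.set_gradient_checkpointing(value)
        self.decoder.set_gradient_checkpointing(value)

    def freeze_feature_extractor(self):
        # Only the speech encoder has a feature extractor to freeze.
        self.encoder.freeze_feature_extractor()


model = TinyComposite(TinySubModel(), TinySubModel())
model.set_gradient_checkpointing(True)
model.freeze_feature_extractor()
```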
@@ -1052,6 +1052,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
|
||||
self.feature_extractor = UniSpeechFeatureExtractor(config)
|
||||
self.feature_projection = UniSpeechFeatureProjection(config)
|
||||
|
||||
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
|
||||
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
|
||||
|
||||
if config.do_stable_layer_norm:
|
||||
|
||||
@@ -1197,6 +1197,8 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor(config)
|
||||
self.feature_projection = Wav2Vec2FeatureProjection(config)
|
||||
|
||||
# model only needs masking vector if mask prob is > 0.0
|
||||
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
|
||||
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
|
||||
|
||||
if config.do_stable_layer_norm:
|
||||
@@ -1209,6 +1211,13 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
Calling this function will disable the gradient computation for the feature extractor so that its parameters
|
||||
will not be updated during training.
|
||||
"""
|
||||
self.feature_extractor._freeze_parameters()
|
||||
|
||||
def _mask_hidden_states(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
|
||||
@@ -19,6 +19,7 @@ import warnings
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from ..auto.tokenization_auto import AutoTokenizer
|
||||
from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
|
||||
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
|
||||
@@ -44,7 +45,7 @@ class Wav2Vec2Processor:
|
||||
raise ValueError(
|
||||
f"`feature_extractor` has to be of type {Wav2Vec2FeatureExtractor.__class__}, but is {type(feature_extractor)}"
|
||||
)
|
||||
if not isinstance(tokenizer, PreTrainedTokenizer):
|
||||
if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
|
||||
raise ValueError(
|
||||
f"`tokenizer` has to be of type {PreTrainedTokenizer.__class__}, but is {type(tokenizer)}"
|
||||
)
|
||||
|
||||
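The `Wav2Vec2Processor` hunk above shows the tokenizer type check in two forms, one against a single class and one against a tuple of the slow and fast tokenizer classes. `isinstance` accepts a tuple and returns `True` if the object is an instance of any member. The sketch below shows the behaviour with hypothetical classes (`SlowTokenizer`, `FastTokenizer`, `check_tokenizer`) in place of the real tokenizer types.

```python
class SlowTokenizer:
    """Hypothetical stand-in for a Python-based tokenizer."""


class FastTokenizer:
    """Hypothetical stand-in for a Rust-backed tokenizer."""


def check_tokenizer(tokenizer):
    # A tuple of types means "an instance of any of these".
    allowed = (SlowTokenizer, FastTokenizer)
    if not isinstance(tokenizer, allowed):
        names = " or ".join(cls.__name__ for cls in allowed)
        raise ValueError(
            f"`tokenizer` has to be of type {names}, but is {type(tokenizer).__name__}"
        )
    return tokenizer


slow_ok = check_tokenizer(SlowTokenizer())
fast_ok = check_tokenizer(FastTokenizer())
```

With a single-class check against `SlowTokenizer`, the `FastTokenizer` instance would be rejected even though it is a valid tokenizer in this sketch.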
@@ -1149,6 +1149,8 @@ class WavLMModel(WavLMPreTrainedModel):
|
||||
self.feature_extractor = WavLMFeatureExtractor(config)
|
||||
self.feature_projection = WavLMFeatureProjection(config)
|
||||
|
||||
# model only needs masking vector if mask prob is > 0.0
|
||||
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
|
||||
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
|
||||
|
||||
if config.do_stable_layer_norm:
|
||||
@@ -1161,6 +1163,13 @@ class WavLMModel(WavLMPreTrainedModel):
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
Calling this function will disable the gradient computation for the feature extractor so that its parameters
|
||||
will not be updated during training.
|
||||
"""
|
||||
self.feature_extractor._freeze_parameters()
|
||||
|
||||
def _mask_hidden_states(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
|
||||