From 1c121916f3adee769eb43d4656b621be60427bbd Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 28 Dec 2021 10:20:51 +0100 Subject: [PATCH] Add Speech Seq2Seq Training script (#14792) * start * add gradient checkpointing and feature extractor freezing * Apply suggestions from code review * up * up * up * correct * up * more changes * up * up * up * remove rst --- examples/pytorch/speech-recognition/README.md | 221 +++++++- .../run_speech_recognition_ctc.py | 3 +- .../run_speech_recognition_seq2seq.py | 502 ++++++++++++++++++ examples/pytorch/test_examples.py | 40 +- .../models/auto/processing_auto.py | 2 +- .../models/hubert/modeling_hubert.py | 3 +- src/transformers/models/sew/modeling_sew.py | 3 +- .../models/sew_d/modeling_sew_d.py | 3 +- .../modeling_speech_encoder_decoder.py | 21 +- .../models/unispeech/modeling_unispeech.py | 3 +- .../models/wav2vec2/modeling_wav2vec2.py | 11 +- .../models/wav2vec2/processing_wav2vec2.py | 3 +- .../models/wavlm/modeling_wavlm.py | 11 +- 13 files changed, 802 insertions(+), 24 deletions(-) create mode 100755 examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py diff --git a/examples/pytorch/speech-recognition/README.md b/examples/pytorch/speech-recognition/README.md index e209000eb4..417ea913f7 100644 --- a/examples/pytorch/speech-recognition/README.md +++ b/examples/pytorch/speech-recognition/README.md @@ -14,12 +14,27 @@ See the License for the specific language governing permissions and limitations under the License. --> -# Automatic Speech Recognition examples +# Automatic Speech Recognition Examples +## Table of Contents -## Connectionist Temporal Classification without Language Model (CTC w/o LM) +- [Automatic Speech Recognition with CTC](#connectionist-temporal-classification) + - [Single GPU example](#single-gpu) + - [Multi GPU example](#multi-gpu) + - [Examples](#examples) + - [TIMIT](#timit) + - [Librispeech](#librispeech) + - [Common Voice](#common-voice) + - [Multilingual Librispeech](#multilingual-librispeech) +- [Automatic Speech Recognition with Sequence-to-Sequence](#sequence-to-sequence) + - [Single GPU example](#single-gpu) + - [Multi GPU example](#multi-gpu) + - [Examples](#examples) + - [Librispeech](#librispeech) -The script [`run_speech_recognition_ctc.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py) can be used to fine-tune any pretrained [Connectionist Temporal Classification Model](https://huggingface.co/transformers/master/model_doc/auto.html?highlight=automodelforctc#automodelforctc) for automatic speech +## Connectionist Temporal Classification + +The script [`run_speech_recognition_ctc.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py) can be used to fine-tune any pretrained [Connectionist Temporal Classification Model](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCTC) for automatic speech recognition on one of the [official speech recognition datasets](https://huggingface.co/datasets?task_ids=task_ids:automatic-speech-recognition) or a custom dataset. Speech recognition models that have been pretrained in unsupervised fashion on audio data alone, *e.g.* [Wav2Vec2](https://huggingface.co/transformers/master/model_doc/wav2vec2.html), [HuBERT](https://huggingface.co/transformers/master/model_doc/hubert.html), [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html), have shown to require only @@ -41,7 +56,7 @@ If the environment variable is not set, the training script might freeze, *i.e.* --- -### Single-GPU +### Single GPU The following command shows how to fine-tune [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html) on [Common Voice](https://huggingface.co/datasets/common_voice) using a single GPU in half-precision. @@ -75,7 +90,7 @@ python run_speech_recognition_ctc.py \ On a single V100 GPU, this script should run in *ca.* 1 hour 20 minutes and yield a CTC loss of **0.39** and word error rate of **0.35**. -### Multi-GPU +### Multi GPU The following command shows how to fine-tune [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html) on [Common Voice](https://huggingface.co/datasets/common_voice) using 8 GPUs in half-precision. @@ -92,7 +107,6 @@ python -m torch.distributed.launch \ --learning_rate="3e-4" \ --warmup_steps="500" \ --evaluation_strategy="steps" \ - --audio_column_name="path" \ --text_column_name="sentence" \ --save_steps="400" \ --eval_steps="100" \ @@ -118,6 +132,8 @@ The presented performances are by no means optimal as no hyper-parameter tuning they can serve as a baseline to improve upon. +#### TIMIT + - [TIMIT](https://huggingface.co/datasets/timit_asr) | Dataset | Dataset Config | Pretrained Model | Word error rate on eval | Phoneme error rate on eval | GPU setup | Training time | Fine-tuned Model & Logs | Command to reproduce | @@ -129,6 +145,7 @@ they can serve as a baseline to improve upon. | [TIMIT](https://huggingface.co/datasets/timit_asr)| - | [ntu-spml/distilhubert](https://huggingface.co/ntu-spml/distilhubert) | 0.68 | - | 1 GPU TITAN RTX | 26min | [here](https://huggingface.co/patrickvonplaten/distilhubert-timit) | [run.sh](https://huggingface.co/patrickvonplaten/distilhubert-timit/blob/main/run.sh) | +#### Librispeech - [Librispeech](https://huggingface.co/datasets/librispeech_asr) @@ -139,7 +156,10 @@ they can serve as a baseline to improve upon. | [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60) | 0.042 | - | 8 GPU V100 | 1h30min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-librispeech-clean-100h-demo-dist) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-librispeech-clean-100h-demo-dist/blob/main/run.sh) | | [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60) | 0.042 | - | 8 GPU V100 | 1h30min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-librispeech-clean-100h-demo-dist) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-librispeech-clean-100h-demo-dist/blob/main/run.sh) | | [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/hubert-large-ll60k](https://huggingface.co/facebook/hubert-large-ll60k) | 0.088 | - | 8 GPU V100 | 1h30min | [here](https://huggingface.co/patrickvonplaten/hubert-librispeech-clean-100h-demo-dist) | [run.sh](https://huggingface.co/patrickvonplaten/hubert-librispeech-clean-100h-demo-dist/blob/main/run.sh) | -| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [asapp/sew-mid-100k](https://huggingface.co/asapp/sew-mid-100k) | 0.167 | | | 8 GPU V100 | 54min | [here](https://huggingface.co/patrickvonplaten/sew-mid-100k-librispeech-clean-100h-ft) | [run.sh](https://huggingface.co/patrickvonplaten/sew-mid-100k-librispeech-clean-100h-ft/blob/main/run.sh) | +| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [asapp/sew-mid-100k](https://huggingface.co/asapp/sew-mid-100k) | 0.167 | | 8 GPU V100 | 54min | [here](https://huggingface.co/patrickvonplaten/sew-mid-100k-librispeech-clean-100h-ft) | [run.sh](https://huggingface.co/patrickvonplaten/sew-mid-100k-librispeech-clean-100h-ft/blob/main/run.sh) | + + +#### Common Voice - [Common Voice](https://huggingface.co/datasets/common_voice) @@ -154,9 +174,196 @@ they can serve as a baseline to improve upon. | [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.31 | - | 8 GPU V100 | 1h05 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft/blob/main/run.sh) | | [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-xls-r-1b](https://huggingface.co/facebook/wav2vec2-xls-r-1b) | 0.21 | - | 2 GPU Titan 24 GB RAM | 15h10 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-xls-r-1b-common_voice-tr-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-1b-common_voice-tr-ft/blob/main/run.sh) | + +#### Multilingual Librispeech + - [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech) | Dataset | Dataset Config | Pretrained Model | Word error rate on eval | Phoneme error rate on eval | GPU setup | Training time | Fine-tuned Model & Logs | Command to reproduce | |-------|------------------------------|-------------|---------------|---------------|----------------------|-------------| -------------| ------- | | [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)| `"german"` | [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) | 0.13 | - | 1 GPU Titan 24 GB RAM | 15h04 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-xlsr-53-300m-mls-german-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-xlsr-53-300m-mls-german-ft/blob/main/run.sh) | | [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)| `"german"` | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.15 | - | 1 GPU Titan 24 GB RAM | 15h04 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-300m-mls-german-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-300m-mls-german-ft/blob/main/run.sh) | + +## Sequence to Sequence + +The script [`run_speech_recognition_seq2seq.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py) can be used to fine-tune any [Speech Sequence-to-Sequence Model](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForSpeechSeq2Seq) for automatic speech +recognition on one of the [official speech recognition datasets](https://huggingface.co/datasets?task_ids=task_ids:automatic-speech-recognition) or a custom dataset. + +A very common use case is to leverage a pretrained speech [encoding model](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModel), +*e.g.* [Wav2Vec2](https://huggingface.co/transformers/master/model_doc/wav2vec2.html), [HuBERT](https://huggingface.co/transformers/master/model_doc/hubert.html), [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html) with a pretrained [text decoding model](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModel), *e.g.* [Bart](https://huggingface.co/docs/transformers/master/en/model_doc/bart#transformers.BartForCausalLM) to create a [SpeechEnocderDecoderModel](https://huggingface.co/docs/transformers/master/en/model_doc/speechencoderdecoder#speech-encoder-decoder-models). +Consequently, the warm-started Speech-Encoder-Decoder model can be fine-tuned in +this script. + +As an example, let's instantiate a *Wav2Vec2-2-Bart* model with the `SpeechEnocderDecoderModel` framework: + +First create an empty repo on `hf.co`: + +```bash +huggingface-cli repo create wav2vec2-2-bart-base +git clone https://huggingface.co//wav2vec2-2-bart-base +cd wav2vec2-2-bart-base +``` + +Next, run the following script **inside** the just cloned repo: + +```py +from transformers import SpeechEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer, Wav2Vec2Processor + +# checkpoints to leverage +encoder_id = "facebook/wav2vec2-base" +decoder_id = "facebook/bart-base" + +# load and save speech-encoder-decoder model +# set some hyper-parameters for training and evaluation +model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_add_adapter=True, encoder_feat_proj_dropout=0.0, encoder_layerdrop=0.0, max_length=200, num_beams=5) +model.config.decoder_start_token_id = model.decoder.config.bos_token_id +model.config.pad_token_id = model.decoder.config.pad_token_id +model.config.eos_token_id = model.decoder.config.eos_token_id +model.save_pretrained("./") + +# load and save processor +feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id) +tokenizer = AutoTokenizer.from_pretrained(decoder_id) +processor = Wav2Vec2Processor(feature_extractor, tokenizer) +processor.save_pretrained("./") +``` + +Finally, we can upload all files: +```bash +git lfs install +git add . && git commit -m "upload model files" && git push +``` + +and link the official `run_speech_recognition_seq2seq.py` script to the folder: + +```bash +ln -s $(realpath /examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py) ./ +``` + +Note that we have added a randomly initialized adapter to `wav2vec2-base` with +`encoder_add_adapter=True` which further samples the output sequence of +`wav2vec2-base` along the time dimension. The reason is that by default a single +output vector of `wav2vec2-base` has a receptive field of *ca.* 25ms (*cf.* with +section *4.2* of the [official Wav2Vec2 paper](https://arxiv.org/pdf/2006.11477.pdf)), which represents a little less a single character. BART on the other hand +makes use of a sentence-piece tokenizer as an input processor so that a single +hidden vector of `bart-base` represents *ca.* 4 characters. To better align +the output of *Wav2Vec2* and *BART*'s hidden vectors for the cross-attention +mechanism, we further subsample *Wav2Vec2*'s output by a factor of 8 by +adding a convolution-based adapter. + +Having warm-started the speech-encoder-decoder model `/wav2vec2-2-bart`, we can now fine-tune it on speech recognition. + +In the script [`run_speech_recognition_seq2seq`], we load the warm-started model, +the feature extractor, and the tokenizer, process a speech recognition dataset, +and then make use of the [`Seq2SeqTrainer`](https://huggingface.co/docs/transformers/master/en/main_classes/trainer#transformers.Seq2SeqTrainer). +Note that it is important to also align the decoder's vocabulary with +the speech transcriptions of the dataset. *E.g.* the [`Librispeech`](https://huggingface.co/datasets/librispeech_asr) has only captilized letters in the transcriptions, +whereas BART was pretrained mostly on normalized text. Thus it is recommended to add +`--do_lower_case` to the fine-tuning script when using a warm-started `SpeechEncoderDecoderModel`. The model is fine-tuned on the standard cross-entropy language modeling +loss for sequence-to-sequence (just like *T5* or *BART* in natural language processing). + +--- +**NOTE** + +If you encounter problems with data preprocessing by setting `--preprocessing_num_workers` > 1, +you might want to set the environment variable `OMP_NUM_THREADS` to 1 as follows: + +```bash +OMP_NUM_THREADS=1 python run_speech_recognition_ctc ... +``` + +If the environment variable is not set, the training script might freeze, *i.e.* see: https://github.com/pytorch/audio/issues/1021#issuecomment-726915239 + +--- + +### Single GPU + +The following command shows how to fine-tune [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html) on [Common Voice](https://huggingface.co/datasets/common_voice) using a single GPU in half-precision. + +```bash +python run_speech_recognition_seq2seq.py \ + --nproc_per_node 8 run_speech_recognition_seq2seq.py \ + --dataset_name="librispeech_asr" \ + --model_name_or_path="./" \ + --dataset_config_name="clean" \ + --train_split_name="train.100" \ + --eval_split_name="validation" \ + --output_dir="./" \ + --preprocessing_num_workers="16" \ + --length_column_name="input_length" \ + --overwrite_output_dir \ + --num_train_epochs="5" \ + --per_device_train_batch_size="8" \ + --per_device_eval_batch_size="8" \ + --gradient_accumulation_steps="8" \ + --learning_rate="3e-4" \ + --warmup_steps="400" \ + --evaluation_strategy="steps" \ + --text_column_name="text" \ + --save_steps="400" \ + --eval_steps="400" \ + --logging_steps="10" \ + --save_total_limit="1" \ + --freeze_feature_extractor \ + --gradient_checkpointing \ + --fp16 \ + --group_by_length \ + --predict_with_generate \ + --generation_max_length="40" \ + --generation_num_beams="1" \ + --do_train --do_eval \ + --do_lower_case +``` + +On a single V100 GPU, this script should run in *ca.* 5 hours and yield a +cross-entropy loss of **0.405** and word error rate of **0.0728**. + +### Multi GPU + +The following command shows how to fine-tune [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html) on [Common Voice](https://huggingface.co/datasets/common_voice) using 8 GPUs in half-precision. + +```bash +python -m torch.distributed.launch \ + --nproc_per_node 8 run_speech_recognition_seq2seq.py \ + --dataset_name="librispeech_asr" \ + --model_name_or_path="./" \ + --dataset_config_name="clean" \ + --train_split_name="train.100" \ + --eval_split_name="validation" \ + --output_dir="./" \ + --preprocessing_num_workers="16" \ + --length_column_name="input_length" \ + --overwrite_output_dir \ + --num_train_epochs="5" \ + --per_device_train_batch_size="8" \ + --per_device_eval_batch_size="8" \ + --gradient_accumulation_steps="1" \ + --learning_rate="3e-4" \ + --warmup_steps="400" \ + --evaluation_strategy="steps" \ + --text_column_name="text" \ + --save_steps="400" \ + --eval_steps="400" \ + --logging_steps="10" \ + --save_total_limit="1" \ + --freeze_feature_extractor \ + --gradient_checkpointing \ + --fp16 \ + --group_by_length \ + --predict_with_generate \ + --do_train --do_eval \ + --do_lower_case +``` + +On 8 V100 GPUs, this script should run in *ca.* 45 minutes and yield a cross-entropy loss of **0.405** and word error rate of **0.0728** + +### Examples + +#### Librispeech + +- [Librispeech](https://huggingface.co/datasets/librispeech_asr) + +| Dataset | Dataset Config | Pretrained Model | Word error rate on eval | Phoneme error rate on eval | GPU setup | Training time | Fine-tuned Model & Logs | Command to reproduce | +|-------|------------------------------|-------------|---------------|---------------|----------------------|-------------| -------------| ------- | +| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) and [facebook/bart-base](https://huggingface.co/facebook/bart-base) | 0.0728 | - | 8 GPU V100 | 45min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-base) | [create_model.py](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-base/blob/main/create_model.py) & [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-base/blob/main/run_librispeech.sh) | +| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60) and [facebook/bart-large](https://huggingface.co/facebook/bart-large) | 0.0486 | - | 8 GPU V100 | 1h20min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-large) | [create_model.py](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-large/blob/main/create_model.py) & [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-large/blob/main/run_librispeech.sh) | diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py index 95e017046d..98a4f1374a 100755 --- a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py +++ b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py @@ -635,14 +635,13 @@ def main(): return metrics - # Now create a single processor + # Now save everything to be able to create a single processor later if is_main_process(training_args.local_rank): # save feature extractor, tokenizer and config feature_extractor.save_pretrained(training_args.output_dir) tokenizer.save_pretrained(training_args.output_dir) config.save_pretrained(training_args.output_dir) - # load processor try: processor = AutoProcessor.from_pretrained(training_args.output_dir) except (OSError, KeyError): diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py new file mode 100755 index 0000000000..f2a3b1ee4a --- /dev/null +++ b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py @@ -0,0 +1,502 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for sequence to sequence speech recognition. +""" +# You can also adapt this script on your own sequence to sequence speech +# recognition task. Pointers for this are left as comments. + +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Union + +import datasets +import torch +from datasets import DatasetDict, load_dataset, load_metric + +import transformers +from transformers import ( + AutoConfig, + AutoFeatureExtractor, + AutoModelForSpeechSeq2Seq, + AutoProcessor, + AutoTokenizer, + HfArgumentParser, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + set_seed, +) +from transformers.trainer_utils import get_last_checkpoint, is_main_process +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.16.0.dev0") + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + feature_extractor_name: Optional[str] = field( + default=None, metadata={"help": "feature extractor name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + freeze_feature_extractor: Optional[bool] = field( + default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."} + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + text_column: Optional[str] = field( + default=None, + metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + audio_column_name: Optional[str] = field( + default="audio", + metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"}, + ) + text_column_name: Optional[str] = field( + default="text", + metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"}, + ) + max_duration_in_seconds: Optional[float] = field( + default=20.0, + metadata={ + "help": "Truncate audio files that are longer than `max_duration_in_seconds` seconds to 'max_duration_in_seconds`" + }, + ) + min_duration_in_seconds: Optional[float] = field( + default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"} + ) + preprocessing_only: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to only do data preprocessing and skip training. " + "This is especially useful when data preprocessing errors out in distributed training due to timeout. " + "In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` " + "so that the cached datasets can consequently be loaded in distributed training" + }, + ) + train_split_name: Optional[str] = field( + default="train", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + eval_split_name: Optional[str] = field( + default="test", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + do_lower_case: Optional[bool] = field( + default=True, + metadata={"help": "Whether the target text should be lower cased."}, + ) + + +@dataclass +class DataCollatorSpeechSeq2SeqWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor ([`Wav2Vec2Processor`]) + The processor used for proccessing the data. + decoder_start_token_id (`int`) + The begin-of-sentence of the decoder. + """ + + processor: Any + decoder_start_token_id: int + + def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: + # split inputs and labels since they have to be of different lenghts and need + # different padding methods + input_features = [{"input_values": feature["input_values"]} for feature in features] + label_features = [{"input_ids": feature["labels"]} for feature in features] + + batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt") + + labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt") + + # replace padding with -100 to ignore loss correctly + labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) + + # if bos token is appended in previous tokenization step, + # cut bos token here as it's append later anyways + if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item(): + labels = labels[:, 1:] + + batch["labels"] = labels + + return batch + + +def main(): + # 1. Parse input arguments + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # 2. Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Set the verbosity to info of the Transformers logger (on main process only): + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() + logger.info("Training/evaluation parameters %s", training_args) + + # 3. Detecting last checkpoint and eventualy continue from last checkpoint + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. + set_seed(training_args.seed) + + # 4. Load dataset + raw_datasets = DatasetDict() + + if training_args.do_train: + raw_datasets["train"] = load_dataset( + data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name + ) + + if training_args.do_eval: + raw_datasets["eval"] = load_dataset( + data_args.dataset_name, data_args.dataset_config_name, split=data_args.eval_split_name + ) + + if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--audio_column_name` to the correct audio column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--text_column_name` to the correct text column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + # 5. Load pretrained model, tokenizer, and feature extractor + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + feature_extractor = AutoFeatureExtractor.from_pretrained( + model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_args.model_name_or_path, + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + if model.config.decoder_start_token_id is None: + raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") + + if model_args.freeze_feature_extractor: + model.freeze_feature_extractor() + + # 6. Resample speech dataset if necassary + dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate + if dataset_sampling_rate != feature_extractor.sampling_rate: + raw_datasets = raw_datasets.cast_column( + data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate) + ) + + # 7. Preprocessing the datasets. + # We need to read the audio files as arrays and tokenize the targets. + max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate + min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate + audio_column_name = data_args.audio_column_name + num_workers = data_args.preprocessing_num_workers + text_column_name = data_args.text_column_name + model_input_name = feature_extractor.model_input_names[0] + do_lower_case = data_args.do_lower_case + + if data_args.max_train_samples is not None: + raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples)) + + if data_args.max_eval_samples is not None: + raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_train_samples)) + + def prepare_dataset(batch): + # process audio + sample = batch[audio_column_name] + inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"]) + # process audio length + batch[model_input_name] = inputs.input_values[0] + batch["input_length"] = len(batch["input_values"]) + + # process targets + input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name] + batch["labels"] = tokenizer(input_str).input_ids + return batch + + with training_args.main_process_first(desc="dataset map pre-processing"): + vectorized_datasets = raw_datasets.map( + prepare_dataset, + remove_columns=next(iter(raw_datasets.values())).column_names, + num_proc=data_args.preprocessing_num_workers, + desc="preprocess train dataset", + ) + + # filter data that is shorter than min_input_length or longer than + # max_input_length + def is_audio_in_length_range(length): + return length > min_input_length and length < max_input_length + + vectorized_datasets = vectorized_datasets.filter( + is_audio_in_length_range, + num_proc=num_workers, + input_columns=["input_length"], + ) + + # for large datasets it is advised to run the preprocessing on a + # single machine first with `args.preprocessing_only` since there will mostly likely + # be a timeout when running the script in distributed mode. + # In a second step `args.preprocessing_only` can then be set to `False` to load the + # cached dataset + if data_args.preprocessing_only: + cache = {k: v.cache_files for k, v in vectorized_datasets.items()} + logger.info(f"Data preprocessing finished. Files cached at {cache}.") + return + + # 8. Load Metric + metric = load_metric("wer") + + def compute_metrics(pred): + pred_ids = pred.predictions + + pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id + + pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) + # we do not want to group tokens when computing the metrics + label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True) + + wer = metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer} + + # 9. Create a single speech processor + if is_main_process(training_args.local_rank): + # save feature extractor, tokenizer and config + feature_extractor.save_pretrained(training_args.output_dir) + tokenizer.save_pretrained(training_args.output_dir) + config.save_pretrained(training_args.output_dir) + + processor = AutoProcessor.from_pretrained(training_args.output_dir) + + # 10. Define data collator + data_collator = DataCollatorSpeechSeq2SeqWithPadding( + processor=processor, decoder_start_token_id=model.config.decoder_start_token_id + ) + + # 11. Initialize Trainer + trainer = Seq2SeqTrainer( + model=model, + args=training_args, + train_dataset=vectorized_datasets["train"] if training_args.do_train else None, + eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None, + tokenizer=feature_extractor, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + ) + + # 12. Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() # Saves the feature extractor too for easy upload + + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples + if data_args.max_train_samples is not None + else len(vectorized_datasets["train"]) + ) + metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"])) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # 13. Evaluation + results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate( + metric_key_prefix="eval", max_length=model.config.max_length, num_beams=model.config.num_beams + ) + max_eval_samples = ( + data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"]) + ) + metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"])) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # 14. Write Training Stats + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "speech recognition"} + if data_args.dataset_name is not None: + kwargs["dataset_tags"] = data_args.dataset_name + if data_args.dataset_config_name is not None: + kwargs["dataset_args"] = data_args.dataset_config_name + kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" + else: + kwargs["dataset"] = data_args.dataset_name + + if training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(**kwargs) + + return results + + +if __name__ == "__main__": + main() diff --git a/examples/pytorch/test_examples.py b/examples/pytorch/test_examples.py index 1a1c2ea06a..f8216175a6 100644 --- a/examples/pytorch/test_examples.py +++ b/examples/pytorch/test_examples.py @@ -59,6 +59,7 @@ if SRC_DIRS is not None: import run_qa as run_squad import run_seq2seq_qa as run_squad_seq2seq import run_speech_recognition_ctc + import run_speech_recognition_seq2seq import run_summarization import run_swag import run_translation @@ -473,6 +474,39 @@ class ExamplesTests(TestCasePlus): result = get_results(tmp_dir) self.assertLess(result["eval_loss"], result["train_loss"]) + def test_run_speech_recognition_seq2seq(self): + stream_handler = logging.StreamHandler(sys.stdout) + logger.addHandler(stream_handler) + + tmp_dir = self.get_auto_remove_tmp_dir() + testargs = f""" + run_speech_recognition_seq2seq.py + --output_dir {tmp_dir} + --model_name_or_path hf-internal-testing/tiny-random-speech-encoder-decoder + --dataset_name hf-internal-testing/librispeech_asr_dummy + --dataset_config_name clean + --train_split_name validation + --eval_split_name validation + --do_train + --do_eval + --learning_rate 1e-4 + --per_device_train_batch_size 2 + --per_device_eval_batch_size 4 + --remove_unused_columns False + --overwrite_output_dir True + --preprocessing_num_workers 16 + --max_steps 10 + --seed 42 + """.split() + + if is_cuda_and_apex_available(): + testargs.append("--fp16") + + with patch.object(sys, "argv", testargs): + run_speech_recognition_seq2seq.main() + result = get_results(tmp_dir) + self.assertLess(result["eval_loss"], result["train_loss"]) + def test_run_audio_classification(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) @@ -521,10 +555,10 @@ class ExamplesTests(TestCasePlus): --dataset_config_names clean --dataset_split_names validation --learning_rate 1e-4 - --per_device_train_batch_size 2 - --per_device_eval_batch_size 2 + --per_device_train_batch_size 4 + --per_device_eval_batch_size 4 --preprocessing_num_workers 16 - --max_train_steps 5 + --max_train_steps 2 --validation_split_percentage 5 --seed 42 """.split() diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index d2c6d496f2..0b046a2835 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -164,7 +164,7 @@ class AutoProcessor: model_type = config_class_to_model_type(type(config).__name__) if getattr(config, "processor_class", None) is not None: - processor_class = config.processor_class + processor_class = processor_class_from_name(config.processor_class) return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs) model_type = config_class_to_model_type(type(config).__name__) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index eaf3f4b697..0f955d734c 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -905,7 +905,8 @@ class HubertModel(HubertPreTrainedModel): self.feature_extractor = HubertFeatureExtractor(config) self.feature_projection = HubertFeatureProjection(config) - self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) if config.do_stable_layer_norm: self.encoder = HubertEncoderStableLayerNorm(config) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 7a467f51df..6b6a8b83a2 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -805,7 +805,8 @@ class SEWModel(SEWPreTrainedModel): self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size) self.feature_dropout = nn.Dropout(config.feat_proj_dropout) - self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) self.encoder = SEWEncoder(config) diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 1095adbdef..636766663e 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -1341,7 +1341,8 @@ class SEWDModel(SEWDPreTrainedModel): self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size) self.feature_dropout = nn.Dropout(config.feat_proj_dropout) - self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) self.encoder = SEWDEncoder(config) diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index e3d70dff9f..b632c020ec 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -181,6 +181,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel): config_class = SpeechEncoderDecoderConfig base_model_prefix = "speech_encoder_decoder" main_input_name = "inputs" + supports_gradient_checkpointing = True def __init__( self, @@ -247,6 +248,11 @@ class SpeechEncoderDecoderModel(PreTrainedModel): f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" ) + def _set_gradient_checkpointing(self, module, value=False): + # call both encoder and decoder function on gradient checkpointing + self.encoder._set_gradient_checkpointing(module, value=value) + self.decoder._set_gradient_checkpointing(module, value=value) + def get_encoder(self): return self.encoder @@ -259,6 +265,13 @@ class SpeechEncoderDecoderModel(PreTrainedModel): def set_output_embeddings(self, new_embeddings): return self.decoder.set_output_embeddings(new_embeddings) + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature extractor of the speech encoder so + that its parameters will not be updated during training. + """ + self.encoder.freeze_feature_extractor() + @classmethod def from_pretrained(cls, *args, **kwargs): # At the moment fast initialization is not supported for composite models @@ -367,7 +380,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel): ) if "config" not in kwargs_encoder: - encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) + encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path, **kwargs_encoder) if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: logger.info( f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " @@ -378,7 +391,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel): kwargs_encoder["config"] = encoder_config - encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder) + encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args) decoder = kwargs_decoder.pop("model", None) if decoder is None: @@ -389,7 +402,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel): ) if "config" not in kwargs_decoder: - decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) + decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: logger.info( f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. " @@ -411,7 +424,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel): "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" ) - decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path) # instantiate config with corresponding kwargs config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index c32da4f309..124b3be1d3 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -1052,7 +1052,8 @@ class UniSpeechModel(UniSpeechPreTrainedModel): self.feature_extractor = UniSpeechFeatureExtractor(config) self.feature_projection = UniSpeechFeatureProjection(config) - self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) if config.do_stable_layer_norm: self.encoder = UniSpeechEncoderStableLayerNorm(config) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 2fd354eba7..f45224ff3b 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1197,7 +1197,9 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): self.feature_extractor = Wav2Vec2FeatureExtractor(config) self.feature_projection = Wav2Vec2FeatureProjection(config) - self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) + # model only needs masking vector if mask prob is > 0.0 + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) if config.do_stable_layer_norm: self.encoder = Wav2Vec2EncoderStableLayerNorm(config) @@ -1209,6 +1211,13 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): # Initialize weights and apply final processing self.post_init() + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature extractor so that its parameters + will not be updated during training. + """ + self.feature_extractor._freeze_parameters() + def _mask_hidden_states( self, hidden_states: torch.FloatTensor, diff --git a/src/transformers/models/wav2vec2/processing_wav2vec2.py b/src/transformers/models/wav2vec2/processing_wav2vec2.py index 4cf2a200ff..58fab74102 100644 --- a/src/transformers/models/wav2vec2/processing_wav2vec2.py +++ b/src/transformers/models/wav2vec2/processing_wav2vec2.py @@ -19,6 +19,7 @@ import warnings from contextlib import contextmanager from ...tokenization_utils import PreTrainedTokenizer +from ...tokenization_utils_fast import PreTrainedTokenizerFast from ..auto.tokenization_auto import AutoTokenizer from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer @@ -44,7 +45,7 @@ class Wav2Vec2Processor: raise ValueError( f"`feature_extractor` has to be of type {Wav2Vec2FeatureExtractor.__class__}, but is {type(feature_extractor)}" ) - if not isinstance(tokenizer, PreTrainedTokenizer): + if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): raise ValueError( f"`tokenizer` has to be of type {PreTrainedTokenizer.__class__}, but is {type(tokenizer)}" ) diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 3b33b51d94..088e1671c5 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -1149,7 +1149,9 @@ class WavLMModel(WavLMPreTrainedModel): self.feature_extractor = WavLMFeatureExtractor(config) self.feature_projection = WavLMFeatureProjection(config) - self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) + # model only needs masking vector if mask prob is > 0.0 + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) if config.do_stable_layer_norm: self.encoder = WavLMEncoderStableLayerNorm(config) @@ -1161,6 +1163,13 @@ class WavLMModel(WavLMPreTrainedModel): # Initialize weights and apply final processing self.post_init() + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature extractor so that its parameters + will not be updated during training. + """ + self.feature_extractor._freeze_parameters() + def _mask_hidden_states( self, hidden_states: torch.FloatTensor,