[WIP] add SpeechT5 model (#18922)

* make SpeechT5 model by copying Wav2Vec2

* add paper to docs

* whoops added docs in wrong file

* remove SpeechT5Tokenizer + put CTC back in the name

* remove deprecated class

* remove unused docstring

* delete SpeechT5FeatureExtractor, use Wav2Vec2FeatureExtractor instead

* remove classes we don't need right now

* initial stab at speech encoder prenet

* add more speech encoder prenet stuff

* improve SpeechEncoderPrenet

* add encoder (not finished yet)

* add relative position bias to self-attention

* add encoder CTC layers

* fix formatting

* add decoder from BART, doesn't work yet

* make it work with generate loop

* wrap the encoder into a speech encoder class

* wrap the decoder in a text decoder class

* changed my mind

* changed my mind again ;-)

* load decoder weights, make it work

* add weights for text decoder postnet

* add SpeechT5ForCTC model that uses only the encoder

* clean up EncoderLayer and DecoderLayer

* implement _init_weights in SpeechT5PreTrainedModel

* cleanup config + Encoder and Decoder

* add head + cross attention masks

* improve doc comments

* fixup

* more cleanup

* more fixup

* TextDecoderPrenet works now, thanks Kendall

* add CTC loss

* add placeholders for other pre/postnets

* add type annotation

* fix freeze_feature_encoder

* set padding tokens to 0 in decoder attention mask

* encoder attention mask downsampling

* remove features_pen calculation

* disable the padding tokens thing again

* fixup

* more fixup

* code review fixes

* rename encoder/decoder wrapper classes

* allow checkpoints to be loaded into SpeechT5Model

* put encoder into wrapper for CTC model

* clean up conversion script

* add encoder for TTS model

* add speech decoder prenet

* add speech decoder post-net

* attempt to reconstruct the generation loop

* add speech generation loop

* clean up generate_speech

* small tweaks

* fix forward pass

* enable always dropout on speech decoder prenet

* sort declaration

* rename models

* fixup

* fix copies

* more fixup

* make consistency checker happy

* add Seq2SeqSpectrogramOutput class

* doc comments

* quick note about loss and labels

* add HiFi-GAN implementation (from Speech2Speech PR)

* rename file

* add vocoder to TTS model

* improve vocoder

* working on tokenizer

* more better tokenizer

* add CTC tokenizer

* fix decode and batch_code in CTC tokenizer

* fix processor

* two processors and feature extractors

* use SpeechT5WaveformFeatureExtractor instead of Wav2Vec2

* cleanup

* more cleanup

* even more fixup

* notebooks

* fix log-mel spectrograms

* support reduction factor

* fixup

* shift spectrograms to right to create decoder inputs

* return correct labels

* add labels for stop token prediction

* fix doc comments

* fixup

* remove SpeechT5ForPreTraining

* more fixup

* update copyright headers

* add usage examples

* add SpeechT5ProcessorForCTC

* fixup

* push unofficial checkpoints to hub

* initial version of tokenizer unit tests

* add slow test

* fix failing tests

* tests for CTC tokenizer

* finish CTC tokenizer tests

* processor tests

* initial test for feature extractors

* tests for spectrogram feature extractor

* fixup

* more fixup

* add decorators

* require speech for tests

* modeling tests

* more tests for ASR model

* fix imports

* add fake tests for the other models

* fixup

* remove jupyter notebooks

* add missing SpeechT5Model tests

* add missing tests for SpeechT5ForCTC

* add missing tests for SpeechT5ForTextToSpeech

* sort tests by name

* fix Hi-Fi GAN tests

* fixup

* add speech-to-speech model

* refactor duplicate speech generation code

* add processor for SpeechToSpeech model

* add usage example

* add tests for speech-to-speech model

* fixup

* enable gradient checkpointing for SpeechT5FeatureEncoder

* code review

* push_to_hub now takes repo_id

* improve doc comments for HiFi-GAN config

* add missing test

* add integration tests

* make number of layers in speech decoder prenet configurable

* rename variable

* rename variables

* add auto classes for TTS and S2S

* REMOVE CTC!!!

* S2S processor does not support save/load_pretrained

* fixup

* these models are now in an auto mapping

* fix doc links

* rename HiFiGAN to HifiGan, remove separate config file

* REMOVE auto classes

* there can be only one

* fixup

* replace assert

* reformat

* feature extractor can process input and target at same time

* update checkpoint names

* fix commit hash
This commit is contained in:
Matthijs Hollemans
2023-02-03 18:43:46 +01:00
committed by GitHub
parent fb13a7df95
commit e4bacf6614
39 changed files with 7545 additions and 14 deletions

View File

@@ -133,6 +133,18 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"BlipTextLMHeadModel", # No need to test it as it is tested by BlipTextVision models
"BridgeTowerTextModel", # No need to test it as it is tested by BridgeTowerModel model.
"BridgeTowerVisionModel", # No need to test it as it is tested by BridgeTowerModel model.
"SpeechT5Decoder", # Building part of bigger (tested) model.
"SpeechT5DecoderWithoutPrenet", # Building part of bigger (tested) model.
"SpeechT5DecoderWithSpeechPrenet", # Building part of bigger (tested) model.
"SpeechT5DecoderWithTextPrenet", # Building part of bigger (tested) model.
"SpeechT5Encoder", # Building part of bigger (tested) model.
"SpeechT5EncoderWithoutPrenet", # Building part of bigger (tested) model.
"SpeechT5EncoderWithSpeechPrenet", # Building part of bigger (tested) model.
"SpeechT5EncoderWithTextPrenet", # Building part of bigger (tested) model.
"SpeechT5SpeechDecoder", # Building part of bigger (tested) model.
"SpeechT5SpeechEncoder", # Building part of bigger (tested) model.
"SpeechT5TextDecoder", # Building part of bigger (tested) model.
"SpeechT5TextEncoder", # Building part of bigger (tested) model.
]
# Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't
@@ -269,6 +281,9 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"AltCLIPTextModel",
"AltCLIPVisionModel",
"AltRobertaModel",
"SpeechT5ForSpeechToSpeech",
"SpeechT5ForTextToSpeech",
"SpeechT5HifiGan",
]
# Update this list for models that have multiple model types for the same
@@ -378,6 +393,8 @@ def is_a_private_model(model):
return True
if model.endswith("Decoder"):
return True
if model.endswith("Prenet"):
return True
return False