TTS fine-tuning for SpeechT5 (#21824)

* wrong argument name

* append eos_token_id

* all tokenizers need mask and ctc_blank tokens

* remove reduction factor from feature extractor

* add proper TTS loss

* did shifting the wrong way around

* mask out padded portions

* remove logits again (don't really need it)

* fix unit tests

* fixup

* pad also returns the decoder attention mask, since that's useful to have

* clean up feature extractor logic

* pad can handle TTS task too

* remove stop_labels from loss calculation

* simplify logic

* fixup

* do -100 masking properly

* small STFT optimization (calculate mel filterbanks only once)

* replace torchaudio fbanks with audio_utils

* remove torchaudio dependency

* simplify & speed up the STFT

* don't serialize window and mel filters

* output cross attentions when generating speech

* add guided attention loss

* fix failing test

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/speecht5/modeling_speecht5.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* change type annotation of attention_mask to LongTensor

* extract loss into class

* remove unused frame_signal_scale argument

* use config object in loss class

* fix type annotations in doc comments

* change optional to just bool

* implement missing tokenizer method

* add deprecation warning

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add deprecation warning for stop_labels

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Matthijs Hollemans
2023-04-18 11:12:30 +02:00
committed by GitHub
parent dacd34568d
commit ac2bc50a10
10 changed files with 448 additions and 234 deletions

View File

@@ -21,7 +21,7 @@ import unittest
from transformers import is_speech_available, is_torch_available
from transformers.models.speecht5 import SpeechT5Tokenizer
from transformers.testing_utils import get_tests_dir, require_torch, require_torchaudio
from transformers.testing_utils import get_tests_dir, require_torch
from transformers.utils import FEATURE_EXTRACTOR_NAME
@@ -35,7 +35,6 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe_char.model")
@require_torch
@require_torchaudio
class SpeechT5ProcessorTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
@@ -52,7 +51,6 @@ class SpeechT5ProcessorTest(unittest.TestCase):
"hop_length": 16,
"win_length": 64,
"win_function": "hann_window",
"frame_signal_scale": 1.0,
"fmin": 80,
"fmax": 7600,
"mel_floor": 1e-10,