TTS fine-tuning for SpeechT5 (#21824)

* wrong argument name

* append eos_token_id

* all tokenizers need mask and ctc_blank tokens

* remove reduction factor from feature extractor

* add proper TTS loss

* did shifting the wrong way around

* mask out padded portions

* remove logits again (don't really need it)

* fix unit tests

* fixup

* pad also returns the decoder attention mask, since that's useful to have

* clean up feature extractor logic

* pad can handle TTS task too

* remove stop_labels from loss calculation

* simplify logic

* fixup

* do -100 masking properly

* small STFT optimization (calculate mel filterbanks only once)

* replace torchaudio fbanks with audio_utils

* remove torchaudio dependency

* simplify & speed up the STFT

* don't serialize window and mel filters

* output cross attentions when generating speech

* add guided attention loss

* fix failing test

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/speecht5/modeling_speecht5.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* change type annotation of attention_mask to LongTensor

* extract loss into class

* remove unused frame_signal_scale argument

* use config object in loss class

* fix type annotations in doc comments

* change optional to just bool

* implement missing tokenizer method

* add deprecation warning

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add deprecation warning for stop_labels

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Matthijs Hollemans
2023-04-18 11:12:30 +02:00
committed by GitHub
parent dacd34568d
commit ac2bc50a10
10 changed files with 448 additions and 234 deletions

View File

@@ -25,10 +25,10 @@ from transformers.testing_utils import (
require_sentencepiece,
require_tokenizers,
require_torch,
require_torchaudio,
slow,
torch_device,
)
from transformers.trainer_utils import set_seed
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -716,7 +716,6 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
@require_torch
@require_torchaudio
@require_sentencepiece
@require_tokenizers
@slow
@@ -991,7 +990,6 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
@require_torch
@require_torchaudio
@require_sentencepiece
@require_tokenizers
@slow
@@ -1005,11 +1003,13 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
model.to(torch_device)
processor = self.default_processor
set_seed(555) # make deterministic
input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
generated_speech = model.generate_speech(input_ids)
self.assertEqual(generated_speech.shape, (1800, model.config.num_mel_bins))
self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins))
@require_torch
@@ -1406,7 +1406,6 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
@require_torch
@require_torchaudio
@require_sentencepiece
@require_tokenizers
@slow