TTS fine-tuning for SpeechT5 (#21824)
* wrong argument name * append eos_token_id * all tokenizers need mask and ctc_blank tokens * remove reduction factor from feature extractor * add proper TTS loss * did shifting the wrong way around * mask out padded portions * remove logits again (don't really need it) * fix unit tests * fixup * pad also returns the decoder attention mask, since that's useful to have * clean up feature extractor logic * pad can handle TTS task too * remove stop_labels from loss calculation * simplify logic * fixup * do -100 masking properly * small STFT optimization (calculate mel filterbanks only once) * replace torchaudio fbanks with audio_utils * remove torchaudio dependency * simplify & speed up the STFT * don't serialize window and mel filters * output cross attentions when generating speech * add guided attention loss * fix failing test * Update src/transformers/models/speecht5/feature_extraction_speecht5.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/speecht5/modeling_speecht5.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * change type annotation of attention_mask to LongTensor * extract loss into class * remove unused frame_signal_scale argument * use config object in loss class * fix type annotations in doc comments * change optional to just bool * implement missing tokenizer method * add deprecation warning * Update src/transformers/models/speecht5/feature_extraction_speecht5.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/speecht5/feature_extraction_speecht5.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * add deprecation warning for stop_labels --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
dacd34568d
commit
ac2bc50a10
@@ -25,10 +25,10 @@ from transformers.testing_utils import (
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
require_torchaudio,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.trainer_utils import set_seed
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -716,7 +716,6 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
@slow
|
||||
@@ -991,7 +990,6 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
@slow
|
||||
@@ -1005,11 +1003,13 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
|
||||
model.to(torch_device)
|
||||
processor = self.default_processor
|
||||
|
||||
set_seed(555) # make deterministic
|
||||
|
||||
input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
|
||||
input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
generated_speech = model.generate_speech(input_ids)
|
||||
self.assertEqual(generated_speech.shape, (1800, model.config.num_mel_bins))
|
||||
self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins))
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -1406,7 +1406,6 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user