TTS fine-tuning for SpeechT5 (#21824)
* wrong argument name * append eos_token_id * all tokenizers need mask and ctc_blank tokens * remove reduction factor from feature extractor * add proper TTS loss * did shifting the wrong way around * mask out padded portions * remove logits again (don't really need it) * fix unit tests * fixup * pad also returns the decoder attention mask, since that's useful to have * clean up feature extractor logic * pad can handle TTS task too * remove stop_labels from loss calculation * simplify logic * fixup * do -100 masking properly * small STFT optimization (calculate mel filterbanks only once) * replace torchaudio fbanks with audio_utils * remove torchaudio dependency * simplify & speed up the STFT * don't serialize window and mel filters * output cross attentions when generating speech * add guided attention loss * fix failing test * Update src/transformers/models/speecht5/feature_extraction_speecht5.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/speecht5/modeling_speecht5.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * change type annotation of attention_mask to LongTensor * extract loss into class * remove unused frame_signal_scale argument * use config object in loss class * fix type annotations in doc comments * change optional to just bool * implement missing tokenizer method * add deprecation warning * Update src/transformers/models/speecht5/feature_extraction_speecht5.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/speecht5/feature_extraction_speecht5.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * add deprecation warning for stop_labels --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
dacd34568d
commit
ac2bc50a10
@@ -21,7 +21,7 @@ import unittest
|
||||
|
||||
from transformers import is_speech_available, is_torch_available
|
||||
from transformers.models.speecht5 import SpeechT5Tokenizer
|
||||
from transformers.testing_utils import get_tests_dir, require_torch, require_torchaudio
|
||||
from transformers.testing_utils import get_tests_dir, require_torch
|
||||
from transformers.utils import FEATURE_EXTRACTOR_NAME
|
||||
|
||||
|
||||
@@ -35,7 +35,6 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe_char.model")
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
class SpeechT5ProcessorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
@@ -52,7 +51,6 @@ class SpeechT5ProcessorTest(unittest.TestCase):
|
||||
"hop_length": 16,
|
||||
"win_length": 64,
|
||||
"win_function": "hann_window",
|
||||
"frame_signal_scale": 1.0,
|
||||
"fmin": 80,
|
||||
"fmax": 7600,
|
||||
"mel_floor": 1e-10,
|
||||
|
||||
Reference in New Issue
Block a user