TTS fine-tuning for SpeechT5 (#21824)
* wrong argument name * append eos_token_id * all tokenizers need mask and ctc_blank tokens * remove reduction factor from feature extractor * add proper TTS loss * did shifting the wrong way around * mask out padded portions * remove logits again (don't really need it) * fix unit tests * fixup * pad also returns the decoder attention mask, since that's useful to have * clean up feature extractor logic * pad can handle TTS task too * remove stop_labels from loss calculation * simplify logic * fixup * do -100 masking properly * small STFT optimization (calculate mel filterbanks only once) * replace torchaudio fbanks with audio_utils * remove torchaudio dependency * simplify & speed up the STFT * don't serialize window and mel filters * output cross attentions when generating speech * add guided attention loss * fix failing test * Update src/transformers/models/speecht5/feature_extraction_speecht5.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Update src/transformers/models/speecht5/modeling_speecht5.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * change type annotation of attention_mask to LongTensor * extract loss into class * remove unused frame_signal_scale argument * use config object in loss class * fix type annotations in doc comments * change optional to just bool * implement missing tokenizer method * add deprecation warning * Update src/transformers/models/speecht5/feature_extraction_speecht5.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/speecht5/feature_extraction_speecht5.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * add deprecation warning for stop_labels --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
dacd34568d
commit
ac2bc50a10
@@ -21,7 +21,7 @@ import unittest
|
||||
import numpy as np
|
||||
|
||||
from transformers import BatchFeature, is_speech_available
|
||||
from transformers.testing_utils import require_torch, require_torchaudio
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.utils.import_utils import is_torch_available
|
||||
|
||||
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||
@@ -67,11 +67,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
|
||||
hop_length=16,
|
||||
win_length=64,
|
||||
win_function="hann_window",
|
||||
frame_signal_scale=1.0,
|
||||
fmin=80,
|
||||
fmax=7600,
|
||||
mel_floor=1e-10,
|
||||
reduction_factor=2,
|
||||
return_attention_mask=True,
|
||||
):
|
||||
self.parent = parent
|
||||
@@ -87,11 +85,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.win_function = win_function
|
||||
self.frame_signal_scale = frame_signal_scale
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.mel_floor = mel_floor
|
||||
self.reduction_factor = reduction_factor
|
||||
self.return_attention_mask = return_attention_mask
|
||||
|
||||
def prepare_feat_extract_dict(self):
|
||||
@@ -104,11 +100,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
|
||||
"hop_length": self.hop_length,
|
||||
"win_length": self.win_length,
|
||||
"win_function": self.win_function,
|
||||
"frame_signal_scale": self.frame_signal_scale,
|
||||
"fmin": self.fmin,
|
||||
"fmax": self.fmax,
|
||||
"mel_floor": self.mel_floor,
|
||||
"reduction_factor": self.reduction_factor,
|
||||
"return_attention_mask": self.return_attention_mask,
|
||||
}
|
||||
|
||||
@@ -147,7 +141,6 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
|
||||
feature_extraction_class = SpeechT5FeatureExtractor if is_speech_available() else None
|
||||
|
||||
@@ -407,10 +400,10 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
|
||||
def test_integration_target(self):
|
||||
# fmt: off
|
||||
EXPECTED_INPUT_VALUES = torch.tensor(
|
||||
[-2.7713, -2.8896, -3.2619, -3.0843, -2.9919, -3.0084, -3.2796, -3.3169,
|
||||
-3.2397, -3.2053, -2.9151, -2.7921, -2.9403, -2.7411, -3.0654, -2.8314,
|
||||
-3.0026, -2.9797, -3.1314, -2.9939, -2.6748, -2.7725, -2.8563, -2.9462,
|
||||
-3.2623, -3.3044, -3.1318, -3.2672, -3.4030, -3.1988]
|
||||
[-2.6870, -3.0104, -3.1356, -3.5352, -3.0044, -3.0353, -3.4719, -3.6777,
|
||||
-3.1520, -2.9435, -2.6553, -2.8795, -2.9944, -2.5921, -3.0279, -3.0386,
|
||||
-3.0864, -3.1291, -3.2353, -2.7444, -2.6831, -2.7287, -3.1761, -3.1571,
|
||||
-3.2726, -3.0582, -3.1007, -3.4533, -3.4695, -3.0998]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
|
||||
@@ -25,10 +25,10 @@ from transformers.testing_utils import (
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
require_torchaudio,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.trainer_utils import set_seed
|
||||
from transformers.utils import cached_property
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -716,7 +716,6 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
@slow
|
||||
@@ -991,7 +990,6 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
@slow
|
||||
@@ -1005,11 +1003,13 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
|
||||
model.to(torch_device)
|
||||
processor = self.default_processor
|
||||
|
||||
set_seed(555) # make deterministic
|
||||
|
||||
input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
|
||||
input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
generated_speech = model.generate_speech(input_ids)
|
||||
self.assertEqual(generated_speech.shape, (1800, model.config.num_mel_bins))
|
||||
self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins))
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -1406,7 +1406,6 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
@slow
|
||||
|
||||
@@ -21,7 +21,7 @@ import unittest
|
||||
|
||||
from transformers import is_speech_available, is_torch_available
|
||||
from transformers.models.speecht5 import SpeechT5Tokenizer
|
||||
from transformers.testing_utils import get_tests_dir, require_torch, require_torchaudio
|
||||
from transformers.testing_utils import get_tests_dir, require_torch
|
||||
from transformers.utils import FEATURE_EXTRACTOR_NAME
|
||||
|
||||
|
||||
@@ -35,7 +35,6 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe_char.model")
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
class SpeechT5ProcessorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
@@ -52,7 +51,6 @@ class SpeechT5ProcessorTest(unittest.TestCase):
|
||||
"hop_length": 16,
|
||||
"win_length": 64,
|
||||
"win_function": "hann_window",
|
||||
"frame_signal_scale": 1.0,
|
||||
"fmin": 80,
|
||||
"fmax": 7600,
|
||||
"mel_floor": 1e-10,
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user