TTS fine-tuning for SpeechT5 (#21824)

* wrong argument name

* append eos_token_id

* all tokenizers need mask and ctc_blank tokens

* remove reduction factor from feature extractor

* add proper TTS loss

* did shifting the wrong way around

* mask out padded portions

* remove logits again (don't really need it)

* fix unit tests

* fixup

* pad also returns the decoder attention mask, since that's useful to have

* clean up feature extractor logic

* pad can handle TTS task too

* remove stop_labels from loss calculation

* simplify logic

* fixup

* do -100 masking properly

* small STFT optimization (calculate mel filterbanks only once)

* replace torchaudio fbanks with audio_utils

* remove torchaudio dependency

* simplify & speed up the STFT

* don't serialize window and mel filters

* output cross attentions when generating speech

* add guided attention loss

* fix failing test

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/speecht5/modeling_speecht5.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* change type annotation of attention_mask to LongTensor

* extract loss into class

* remove unused frame_signal_scale argument

* use config object in loss class

* fix type annotations in doc comments

* change optional to just bool

* implement missing tokenizer method

* add deprecation warning

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add deprecation warning for stop_labels

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Matthijs Hollemans
2023-04-18 11:12:30 +02:00
committed by GitHub
parent dacd34568d
commit ac2bc50a10
10 changed files with 448 additions and 234 deletions

View File

@@ -21,7 +21,7 @@ import unittest
import numpy as np
from transformers import BatchFeature, is_speech_available
from transformers.testing_utils import require_torch, require_torchaudio
from transformers.testing_utils import require_torch
from transformers.utils.import_utils import is_torch_available
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
@@ -67,11 +67,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
hop_length=16,
win_length=64,
win_function="hann_window",
frame_signal_scale=1.0,
fmin=80,
fmax=7600,
mel_floor=1e-10,
reduction_factor=2,
return_attention_mask=True,
):
self.parent = parent
@@ -87,11 +85,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
self.hop_length = hop_length
self.win_length = win_length
self.win_function = win_function
self.frame_signal_scale = frame_signal_scale
self.fmin = fmin
self.fmax = fmax
self.mel_floor = mel_floor
self.reduction_factor = reduction_factor
self.return_attention_mask = return_attention_mask
def prepare_feat_extract_dict(self):
@@ -104,11 +100,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
"hop_length": self.hop_length,
"win_length": self.win_length,
"win_function": self.win_function,
"frame_signal_scale": self.frame_signal_scale,
"fmin": self.fmin,
"fmax": self.fmax,
"mel_floor": self.mel_floor,
"reduction_factor": self.reduction_factor,
"return_attention_mask": self.return_attention_mask,
}
@@ -147,7 +141,6 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
@require_torch
@require_torchaudio
class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = SpeechT5FeatureExtractor if is_speech_available() else None
@@ -407,10 +400,10 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
def test_integration_target(self):
# fmt: off
EXPECTED_INPUT_VALUES = torch.tensor(
[-2.7713, -2.8896, -3.2619, -3.0843, -2.9919, -3.0084, -3.2796, -3.3169,
-3.2397, -3.2053, -2.9151, -2.7921, -2.9403, -2.7411, -3.0654, -2.8314,
-3.0026, -2.9797, -3.1314, -2.9939, -2.6748, -2.7725, -2.8563, -2.9462,
-3.2623, -3.3044, -3.1318, -3.2672, -3.4030, -3.1988]
[-2.6870, -3.0104, -3.1356, -3.5352, -3.0044, -3.0353, -3.4719, -3.6777,
-3.1520, -2.9435, -2.6553, -2.8795, -2.9944, -2.5921, -3.0279, -3.0386,
-3.0864, -3.1291, -3.2353, -2.7444, -2.6831, -2.7287, -3.1761, -3.1571,
-3.2726, -3.0582, -3.1007, -3.4533, -3.4695, -3.0998]
)
# fmt: on

View File

@@ -25,10 +25,10 @@ from transformers.testing_utils import (
require_sentencepiece,
require_tokenizers,
require_torch,
require_torchaudio,
slow,
torch_device,
)
from transformers.trainer_utils import set_seed
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -716,7 +716,6 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
@require_torch
@require_torchaudio
@require_sentencepiece
@require_tokenizers
@slow
@@ -991,7 +990,6 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
@require_torch
@require_torchaudio
@require_sentencepiece
@require_tokenizers
@slow
@@ -1005,11 +1003,13 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
model.to(torch_device)
processor = self.default_processor
set_seed(555) # make deterministic
input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)
generated_speech = model.generate_speech(input_ids)
self.assertEqual(generated_speech.shape, (1800, model.config.num_mel_bins))
self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins))
@require_torch
@@ -1406,7 +1406,6 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
@require_torch
@require_torchaudio
@require_sentencepiece
@require_tokenizers
@slow

View File

@@ -21,7 +21,7 @@ import unittest
from transformers import is_speech_available, is_torch_available
from transformers.models.speecht5 import SpeechT5Tokenizer
from transformers.testing_utils import get_tests_dir, require_torch, require_torchaudio
from transformers.testing_utils import get_tests_dir, require_torch
from transformers.utils import FEATURE_EXTRACTOR_NAME
@@ -35,7 +35,6 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe_char.model")
@require_torch
@require_torchaudio
class SpeechT5ProcessorTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
@@ -52,7 +51,6 @@ class SpeechT5ProcessorTest(unittest.TestCase):
"hop_length": 16,
"win_length": 64,
"win_function": "hann_window",
"frame_signal_scale": 1.0,
"fmin": 80,
"fmax": 7600,
"mel_floor": 1e-10,

File diff suppressed because one or more lines are too long