From 4e945660187d1789cda8463f1ff3786a2dbd7585 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Tue, 4 Jul 2023 16:03:27 +0100 Subject: [PATCH] Fix audio feature extractor deps (#24636) * Fix audio feature extractor deps * use audio utils window over torch window --- src/transformers/__init__.py | 19 ++++++------- src/transformers/models/mctct/__init__.py | 21 ++------------ .../models/mctct/feature_extraction_mctct.py | 9 ++---- src/transformers/models/speecht5/__init__.py | 19 ++----------- .../speecht5/feature_extraction_speecht5.py | 6 ++-- src/transformers/models/tvlt/__init__.py | 17 ++--------- .../utils/dummy_speech_objects.py | 28 ------------------- .../test_feature_extraction_encodec.py | 7 ++--- .../mctct/test_feature_extraction_mctct.py | 7 ++--- .../test_feature_extraction_speecht5.py | 7 ++--- .../tvlt/test_feature_extraction_tvlt.py | 7 ++--- 11 files changed, 28 insertions(+), 119 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 882ac752b2..9e6fa1bf38 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -285,6 +285,7 @@ _import_structure = { "models.encodec": [ "ENCODEC_PRETRAINED_CONFIG_ARCHIVE_MAP", "EncodecConfig", + "EncodecFeatureExtractor", ], "models.encoder_decoder": ["EncoderDecoderConfig"], "models.ernie": [ @@ -388,7 +389,7 @@ _import_structure = { "models.maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig", "MaskFormerSwinConfig"], "models.mbart": ["MBartConfig"], "models.mbart50": [], - "models.mctct": ["MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MCTCTConfig", "MCTCTProcessor"], + "models.mctct": ["MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MCTCTConfig", "MCTCTFeatureExtractor", "MCTCTProcessor"], "models.mega": ["MEGA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegaConfig"], "models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"], "models.megatron_gpt2": [], @@ -481,6 +482,7 @@ _import_structure = { "SPEECHT5_PRETRAINED_CONFIG_ARCHIVE_MAP", "SPEECHT5_PRETRAINED_HIFIGAN_CONFIG_ARCHIVE_MAP", "SpeechT5Config", + "SpeechT5FeatureExtractor", "SpeechT5HifiGanConfig", "SpeechT5Processor", ], @@ -519,6 +521,7 @@ _import_structure = { "models.tvlt": [ "TVLT_PRETRAINED_CONFIG_ARCHIVE_MAP", "TvltConfig", + "TvltFeatureExtractor", "TvltProcessor", ], "models.umt5": [], @@ -843,11 +846,7 @@ except OptionalDependencyNotAvailable: ] else: _import_structure["models.audio_spectrogram_transformer"].append("ASTFeatureExtractor") - _import_structure["models.encodec"].append("EncodecFeatureExtractor") - _import_structure["models.mctct"].append("MCTCTFeatureExtractor") _import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor") - _import_structure["models.speecht5"].append("SpeechT5FeatureExtractor") - _import_structure["models.tvlt"].append("TvltFeatureExtractor") # Tensorflow-text-specific objects try: @@ -4170,6 +4169,7 @@ if TYPE_CHECKING: from .models.encodec import ( ENCODEC_PRETRAINED_CONFIG_ARCHIVE_MAP, EncodecConfig, + EncodecFeatureExtractor, ) from .models.encoder_decoder import EncoderDecoderConfig from .models.ernie import ERNIE_PRETRAINED_CONFIG_ARCHIVE_MAP, ErnieConfig @@ -4265,7 +4265,7 @@ if TYPE_CHECKING: from .models.mask2former import MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, Mask2FormerConfig from .models.maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig, MaskFormerSwinConfig from .models.mbart import MBartConfig - from .models.mctct import MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP, MCTCTConfig, MCTCTProcessor + from .models.mctct import MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP, MCTCTConfig, MCTCTFeatureExtractor, MCTCTProcessor from .models.mega import MEGA_PRETRAINED_CONFIG_ARCHIVE_MAP, MegaConfig from .models.megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig from .models.mgp_str import MGP_STR_PRETRAINED_CONFIG_ARCHIVE_MAP, MgpstrConfig, MgpstrProcessor, MgpstrTokenizer @@ -4355,6 +4355,7 @@ if TYPE_CHECKING: SPEECHT5_PRETRAINED_CONFIG_ARCHIVE_MAP, SPEECHT5_PRETRAINED_HIFIGAN_CONFIG_ARCHIVE_MAP, SpeechT5Config, + SpeechT5FeatureExtractor, SpeechT5HifiGanConfig, SpeechT5Processor, ) @@ -4386,7 +4387,7 @@ if TYPE_CHECKING: TransfoXLTokenizer, ) from .models.trocr import TROCR_PRETRAINED_CONFIG_ARCHIVE_MAP, TrOCRConfig, TrOCRProcessor - from .models.tvlt import TVLT_PRETRAINED_CONFIG_ARCHIVE_MAP, TvltConfig, TvltProcessor + from .models.tvlt import TVLT_PRETRAINED_CONFIG_ARCHIVE_MAP, TvltConfig, TvltFeatureExtractor, TvltProcessor from .models.unispeech import UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechConfig from .models.unispeech_sat import UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP, UniSpeechSatConfig from .models.upernet import UperNetConfig @@ -4681,11 +4682,7 @@ if TYPE_CHECKING: from .utils.dummy_speech_objects import * else: from .models.audio_spectrogram_transformer import ASTFeatureExtractor - from .models.encodec import EncodecFeatureExtractor - from .models.mctct import MCTCTFeatureExtractor from .models.speech_to_text import Speech2TextFeatureExtractor - from .models.speecht5 import SpeechT5FeatureExtractor - from .models.tvlt import TvltFeatureExtractor try: if not is_tensorflow_text_available(): diff --git a/src/transformers/models/mctct/__init__.py b/src/transformers/models/mctct/__init__.py index 5da754fc51..f389ebd9e7 100644 --- a/src/transformers/models/mctct/__init__.py +++ b/src/transformers/models/mctct/__init__.py @@ -13,24 +13,16 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_speech_available, is_torch_available +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available _import_structure = { "configuration_mctct": ["MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MCTCTConfig"], + "feature_extraction_mctct": ["MCTCTFeatureExtractor"], "processing_mctct": ["MCTCTProcessor"], } -try: - if not is_speech_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["feature_extraction_mctct"] = ["MCTCTFeatureExtractor"] - - try: if not is_torch_available(): raise OptionalDependencyNotAvailable() @@ -47,16 +39,9 @@ else: if TYPE_CHECKING: from .configuration_mctct import MCTCT_PRETRAINED_CONFIG_ARCHIVE_MAP, MCTCTConfig + from .feature_extraction_mctct import MCTCTFeatureExtractor from .processing_mctct import MCTCTProcessor - try: - if not is_speech_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .feature_extraction_mctct import MCTCTFeatureExtractor - try: if not is_torch_available(): raise OptionalDependencyNotAvailable() diff --git a/src/transformers/models/mctct/feature_extraction_mctct.py b/src/transformers/models/mctct/feature_extraction_mctct.py index 9e9e276c16..23ae02ecad 100644 --- a/src/transformers/models/mctct/feature_extraction_mctct.py +++ b/src/transformers/models/mctct/feature_extraction_mctct.py @@ -19,9 +19,8 @@ Feature extractor class for M-CTC-T from typing import List, Optional, Union import numpy as np -import torch -from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram +from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...feature_extraction_utils import BatchFeature from ...file_utils import PaddingStrategy, TensorType @@ -110,11 +109,9 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor): Extracts MFSC Features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC code. """ if self.win_function == "hamming_window": - window = torch.hamming_window(window_length=self.sample_size, periodic=False, alpha=0.54, beta=0.46) + window = window_function(window_length=self.sample_size, name=self.win_function, periodic=False) else: - window = getattr(torch, self.win_function)() - - window = window.numpy() + window = window_function(window_length=self.sample_size, name=self.win_function) fbanks = mel_filter_bank( num_frequency_bins=self.n_freqs, diff --git a/src/transformers/models/speecht5/__init__.py b/src/transformers/models/speecht5/__init__.py index d1f8a6d8f7..20606dda51 100644 --- a/src/transformers/models/speecht5/__init__.py +++ b/src/transformers/models/speecht5/__init__.py @@ -17,7 +17,6 @@ from ...utils import ( OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, - is_speech_available, is_torch_available, ) @@ -29,6 +28,7 @@ _import_structure = { "SpeechT5Config", "SpeechT5HifiGanConfig", ], + "feature_extraction_speecht5": ["SpeechT5FeatureExtractor"], "processing_speecht5": ["SpeechT5Processor"], } @@ -40,14 +40,6 @@ except OptionalDependencyNotAvailable: else: _import_structure["tokenization_speecht5"] = ["SpeechT5Tokenizer"] -try: - if not is_speech_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["feature_extraction_speecht5"] = ["SpeechT5FeatureExtractor"] - try: if not is_torch_available(): raise OptionalDependencyNotAvailable() @@ -71,6 +63,7 @@ if TYPE_CHECKING: SpeechT5Config, SpeechT5HifiGanConfig, ) + from .feature_extraction_speecht5 import SpeechT5FeatureExtractor from .processing_speecht5 import SpeechT5Processor try: @@ -81,14 +74,6 @@ if TYPE_CHECKING: else: from .tokenization_speecht5 import SpeechT5Tokenizer - try: - if not is_speech_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .feature_extraction_speecht5 import SpeechT5FeatureExtractor - try: if not is_torch_available(): raise OptionalDependencyNotAvailable() diff --git a/src/transformers/models/speecht5/feature_extraction_speecht5.py b/src/transformers/models/speecht5/feature_extraction_speecht5.py index dd5ff4c8a1..84d51e97df 100644 --- a/src/transformers/models/speecht5/feature_extraction_speecht5.py +++ b/src/transformers/models/speecht5/feature_extraction_speecht5.py @@ -18,9 +18,8 @@ import warnings from typing import Any, Dict, List, Optional, Union import numpy as np -import torch -from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram +from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function from ...feature_extraction_sequence_utils import SequenceFeatureExtractor from ...feature_extraction_utils import BatchFeature from ...utils import PaddingStrategy, TensorType, logging @@ -113,8 +112,7 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor): self.n_fft = optimal_fft_length(self.sample_size) self.n_freqs = (self.n_fft // 2) + 1 - window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True) - self.window = window.numpy().astype(np.float64) + self.window = window_function(window_length=self.sample_size, name=self.win_function, periodic=True) self.mel_filters = mel_filter_bank( num_frequency_bins=self.n_freqs, diff --git a/src/transformers/models/tvlt/__init__.py b/src/transformers/models/tvlt/__init__.py index 5ca90bb744..86c0f7c1c0 100644 --- a/src/transformers/models/tvlt/__init__.py +++ b/src/transformers/models/tvlt/__init__.py @@ -20,7 +20,6 @@ from typing import TYPE_CHECKING from ...utils import ( OptionalDependencyNotAvailable, _LazyModule, - is_speech_available, is_torch_available, is_vision_available, ) @@ -28,6 +27,7 @@ from ...utils import ( _import_structure = { "configuration_tvlt": ["TVLT_PRETRAINED_CONFIG_ARCHIVE_MAP", "TvltConfig"], + "feature_extraction_tvlt": ["TvltFeatureExtractor"], "processing_tvlt": ["TvltProcessor"], } @@ -53,17 +53,11 @@ except OptionalDependencyNotAvailable: else: _import_structure["image_processing_tvlt"] = ["TvltImageProcessor"] -try: - if not is_speech_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["feature_extraction_tvlt"] = ["TvltFeatureExtractor"] if TYPE_CHECKING: from .configuration_tvlt import TVLT_PRETRAINED_CONFIG_ARCHIVE_MAP, TvltConfig from .processing_tvlt import TvltProcessor + from .feature_extraction_tvlt import TvltFeatureExtractor try: if not is_torch_available(): @@ -87,13 +81,6 @@ if TYPE_CHECKING: else: from .image_processing_tvlt import TvltImageProcessor - try: - if not is_speech_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .feature_extraction_tvlt import TvltFeatureExtractor else: import sys diff --git a/src/transformers/utils/dummy_speech_objects.py b/src/transformers/utils/dummy_speech_objects.py index 6c59158858..0bf08ebea4 100644 --- a/src/transformers/utils/dummy_speech_objects.py +++ b/src/transformers/utils/dummy_speech_objects.py @@ -9,36 +9,8 @@ class ASTFeatureExtractor(metaclass=DummyObject): requires_backends(self, ["speech"]) -class EncodecFeatureExtractor(metaclass=DummyObject): - _backends = ["speech"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["speech"]) - - -class MCTCTFeatureExtractor(metaclass=DummyObject): - _backends = ["speech"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["speech"]) - - class Speech2TextFeatureExtractor(metaclass=DummyObject): _backends = ["speech"] def __init__(self, *args, **kwargs): requires_backends(self, ["speech"]) - - -class SpeechT5FeatureExtractor(metaclass=DummyObject): - _backends = ["speech"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["speech"]) - - -class TvltFeatureExtractor(metaclass=DummyObject): - _backends = ["speech"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["speech"]) diff --git a/tests/models/encodec/test_feature_extraction_encodec.py b/tests/models/encodec/test_feature_extraction_encodec.py index 95639fcda5..bf40135c32 100644 --- a/tests/models/encodec/test_feature_extraction_encodec.py +++ b/tests/models/encodec/test_feature_extraction_encodec.py @@ -20,16 +20,13 @@ import unittest import numpy as np -from transformers import is_speech_available +from transformers import EncodecFeatureExtractor from transformers.testing_utils import require_torch from transformers.utils.import_utils import is_torch_available from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin -if is_speech_available(): - from transformers import EncodecFeatureExtractor - if is_torch_available(): import torch @@ -103,7 +100,7 @@ class EnCodecFeatureExtractionTester(unittest.TestCase): @require_torch class EnCodecFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): - feature_extraction_class = EncodecFeatureExtractor if is_speech_available() else None + feature_extraction_class = EncodecFeatureExtractor def setUp(self): self.feat_extract_tester = EnCodecFeatureExtractionTester(self) diff --git a/tests/models/mctct/test_feature_extraction_mctct.py b/tests/models/mctct/test_feature_extraction_mctct.py index f3d8f0fea9..f1825c3640 100644 --- a/tests/models/mctct/test_feature_extraction_mctct.py +++ b/tests/models/mctct/test_feature_extraction_mctct.py @@ -20,15 +20,12 @@ import unittest import numpy as np -from transformers import is_speech_available +from transformers import MCTCTFeatureExtractor from transformers.testing_utils import require_torch from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin -if is_speech_available(): - from transformers import MCTCTFeatureExtractor - global_rng = random.Random() @@ -102,7 +99,7 @@ class MCTCTFeatureExtractionTester(unittest.TestCase): @require_torch class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): - feature_extraction_class = MCTCTFeatureExtractor if is_speech_available() else None + feature_extraction_class = MCTCTFeatureExtractor def setUp(self): self.feat_extract_tester = MCTCTFeatureExtractionTester(self) diff --git a/tests/models/speecht5/test_feature_extraction_speecht5.py b/tests/models/speecht5/test_feature_extraction_speecht5.py index a09bf7f8ae..038e6117f7 100644 --- a/tests/models/speecht5/test_feature_extraction_speecht5.py +++ b/tests/models/speecht5/test_feature_extraction_speecht5.py @@ -20,16 +20,13 @@ import unittest import numpy as np -from transformers import BatchFeature, is_speech_available +from transformers import BatchFeature, SpeechT5FeatureExtractor from transformers.testing_utils import require_torch from transformers.utils.import_utils import is_torch_available from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin -if is_speech_available(): - from transformers import SpeechT5FeatureExtractor - if is_torch_available(): import torch @@ -142,7 +139,7 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase): @require_torch class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): - feature_extraction_class = SpeechT5FeatureExtractor if is_speech_available() else None + feature_extraction_class = SpeechT5FeatureExtractor def setUp(self): self.feat_extract_tester = SpeechT5FeatureExtractionTester(self) diff --git a/tests/models/tvlt/test_feature_extraction_tvlt.py b/tests/models/tvlt/test_feature_extraction_tvlt.py index 051708a306..ad37b1f984 100644 --- a/tests/models/tvlt/test_feature_extraction_tvlt.py +++ b/tests/models/tvlt/test_feature_extraction_tvlt.py @@ -22,7 +22,7 @@ import unittest import numpy as np -from transformers import is_datasets_available, is_speech_available +from transformers import TvltFeatureExtractor, is_datasets_available from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio from transformers.utils.import_utils import is_torch_available @@ -35,9 +35,6 @@ if is_torch_available(): if is_datasets_available(): from datasets import load_dataset -if is_speech_available(): - from transformers import TvltFeatureExtractor - global_rng = random.Random() @@ -111,7 +108,7 @@ class TvltFeatureExtractionTester(unittest.TestCase): @require_torch @require_torchaudio class TvltFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase): - feature_extraction_class = TvltFeatureExtractor if is_speech_available() else None + feature_extraction_class = TvltFeatureExtractor def setUp(self): self.feat_extract_tester = TvltFeatureExtractionTester(self)