From 6f4424bb086d3d090855862be5aff64eb8ed7101 Mon Sep 17 00:00:00 2001 From: Omar Sanseviero Date: Fri, 18 Aug 2023 22:01:35 +0200 Subject: [PATCH] Make TTS automodels importable (#25595) * Add auto model for spectrogram/waveform * Add doc and install * Add dummy objects * Did I miss anything? --- docs/source/en/model_doc/auto.md | 8 ++++++++ src/transformers/__init__.py | 4 ++++ src/transformers/models/auto/__init__.py | 4 ++++ src/transformers/utils/dummy_pt_objects.py | 14 ++++++++++++++ utils/update_metadata.py | 1 + 5 files changed, 31 insertions(+) diff --git a/docs/source/en/model_doc/auto.md b/docs/source/en/model_doc/auto.md index f493e208ee..9390b96fc5 100644 --- a/docs/source/en/model_doc/auto.md +++ b/docs/source/en/model_doc/auto.md @@ -330,6 +330,14 @@ The following auto classes are available for the following audio tasks. [[autodoc]] AutoModelForAudioXVector +### AutoModelForTextToSpectrogram + +[[autodoc]] AutoModelForTextToSpectrogram + +### AutoModelForTextToWaveform + +[[autodoc]] AutoModelForTextToWaveform + ## Multimodal The following auto classes are available for the following multimodal tasks. diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 0a9bc3257d..9fc2d41bc1 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1135,6 +1135,8 @@ else: "AutoModelForSpeechSeq2Seq", "AutoModelForTableQuestionAnswering", "AutoModelForTextEncoding", + "AutoModelForTextToSpectrogram", + "AutoModelForTextToWaveform", "AutoModelForTokenClassification", "AutoModelForUniversalSegmentation", "AutoModelForVideoClassification", @@ -5050,6 +5052,8 @@ if TYPE_CHECKING: AutoModelForSpeechSeq2Seq, AutoModelForTableQuestionAnswering, AutoModelForTextEncoding, + AutoModelForTextToSpectrogram, + AutoModelForTextToWaveform, AutoModelForTokenClassification, AutoModelForUniversalSegmentation, AutoModelForVideoClassification, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 3a5095c217..12d79822fd 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -101,6 +101,8 @@ else: "AutoModelForSequenceClassification", "AutoModelForSpeechSeq2Seq", "AutoModelForTableQuestionAnswering", + "AutoModelForTextToSpectrogram", + "AutoModelForTextToWaveform", "AutoModelForTokenClassification", "AutoModelForUniversalSegmentation", "AutoModelForVideoClassification", @@ -280,6 +282,8 @@ if TYPE_CHECKING: AutoModelForSpeechSeq2Seq, AutoModelForTableQuestionAnswering, AutoModelForTextEncoding, + AutoModelForTextToSpectrogram, + AutoModelForTextToWaveform, AutoModelForTokenClassification, AutoModelForUniversalSegmentation, AutoModelForVideoClassification, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 1e8baba71c..b399568814 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -742,6 +742,20 @@ class AutoModelForTextEncoding(metaclass=DummyObject): requires_backends(self, ["torch"]) +class AutoModelForTextToSpectrogram(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class AutoModelForTextToWaveform(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class AutoModelForTokenClassification(metaclass=DummyObject): _backends = ["torch"] diff --git a/utils/update_metadata.py b/utils/update_metadata.py index c3aec0a868..3cda173e5e 100644 --- a/utils/update_metadata.py +++ b/utils/update_metadata.py @@ -115,6 +115,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [ ("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"), ("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"), ("mask-generation", "MODEL_FOR_MASK_GENERATION_MAPPING_NAMES", "AutoModelForMaskGeneration"), + ("text-to-audio", "MODEL_FOR_TEXT_TO_SPECTROGRAM_NAMES", "AutoModelForTextToSpectrogram"), ("text-to-audio", "MODEL_FOR_TEXT_TO_WAVEFORM_NAMES", "AutoModelForTextToWaveform"), ]